Automated Inference with Adaptive Batches
Soham De
Joint work with Abhay Yadav, David Jacobs, Tom Goldstein
University of Maryland
OVERVIEW
Most machine learning models use SGD for training.
BUT… noisy gradients & large variance.
Hard to automate stepsize selection & stopping conditions.

In this work: Big Batch SGD
Adaptively grow the batch size based on the amount of noise in the gradients.
Easy automated stepsize selection.
MOST MODEL FITTING PROBLEMS LOOK LIKE THIS

\min_x \; \ell(x) := \frac{1}{N} \sum_{i=1}^{N} f(x, z_i)

Applications: SVM, logistic regression, matrix factorization, neural nets, …

We also consider the more general problem:

\min_x \; \ell(x) := \mathbb{E}_{z \sim p}[f(x, z)]
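To make the notation concrete, here is a minimal NumPy sketch of the finite-sum setup, using ridge regression as an illustrative per-example loss f(x, z_i). The function names and the regularization weight are assumptions chosen for illustration, not part of the method.

```python
import numpy as np

LAM = 1e-3  # illustrative ridge regularization weight (an assumption, not from the slides)

def f_grad(x, z):
    """Gradient of one term f(x, z) = 0.5*(a @ x - b)**2 + 0.5*LAM*||x||^2, with z = (a, b)."""
    a, b = z
    return a * (a @ x - b) + LAM * x

def full_grad(x, data):
    """The exact gradient of l(x): the average of all N per-example gradients."""
    return np.mean([f_grad(x, z) for z in data], axis=0)
```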
SGD
select data → compute gradient → update

Exact gradient: g_t = \frac{1}{N} \sum_{i=1}^{N} \nabla f(x_t, z_i), \qquad update: x_{t+1} = x_t - \alpha_t g_t

SGD instead uses a single sampled term, e.g. g_t \approx \nabla f(x_t, z_8).

Far from the optimum: the solution improves.
Close to the optimum: the solution gets worse.
SGD
The error must decrease as we approach the solution.
Classical solution: shrink the stepsize so that \lim_{t \to \infty} \alpha_t = 0.
But it is hard to pick the stepsize schedule.
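For contrast with what follows, a sketch of classical SGD with a hand-tuned 1/t stepsize decay, reusing the illustrative f_grad above. The schedule constants are arbitrary placeholders.

```python
import numpy as np

def sgd_decay(data, x0, alpha0=0.1, decay=1e-3, iters=10_000, seed=0):
    """Plain SGD with a shrinking stepsize alpha_t = alpha0 / (1 + decay*t).

    Both alpha0 and decay must be tuned per problem, which is exactly the
    schedule-selection burden Big Batch SGD aims to remove.
    """
    rng = np.random.default_rng(seed)
    x = x0.copy()
    for t in range(iters):
        z = data[rng.integers(len(data))]      # sample a single data point
        alpha_t = alpha0 / (1.0 + decay * t)   # stepsize -> 0 as t -> infinity
        x = x - alpha_t * f_grad(x, z)
    return x
```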
USE BIGGER BATCHES?
select data → compute gradient → update

g_t \approx \nabla \ell_B(x_t) \ \text{(the gradient on a batch } B\text{)}, \qquad x_{t+1} = x_t - \alpha_t g_t

Lower variance: the solution keeps getting better, even close to the optimum.
GROWING BATCHES

Regime 1: far from optimal
• Noisy gradients improve the solution
• Large batches do unnecessary work + get stuck in local minima
• Small batches work well!

Regime 2: close to optimal
• Noisy gradients with high variance worsen the solution
• Small batches require stepsize decay (hard to tune)
• Large batches work well!

Adaptively grow batches over time!
PRELIMINARIES
Standard result in stochastic optimization:

Lemma. A sufficient condition for -\nabla \ell_B(x) to be a descent direction is

\underbrace{\|\nabla \ell_B(x) - \nabla \ell(x)\|^2}_{\text{error}} < \underbrace{\|\nabla \ell_B(x)\|^2}_{\text{approximate gradient}}.

If the error is small relative to the gradient, we have a descent direction.
How big is this error? How large does the batch need to be to guarantee this?
ERROR BOUND
Theorem. Assume f has L_z-Lipschitz dependence on the data z. Then the expected error is uniformly bounded by

\underbrace{\mathbb{E}_B \|\nabla \ell_B(x) - \nabla \ell(x)\|^2}_{\text{expected error}} = \underbrace{\operatorname{Tr} \operatorname{Var}_B(\nabla \ell_B(x))}_{\text{batch gradient variance}} \le \frac{1}{|B|} \cdot \underbrace{4 L_z^2 \operatorname{Tr} \operatorname{Var}_z(z)}_{\text{data variance}}.

Expected error (gradient variance) < data variance / batch size: a higher batch size means lower gradient variance.
ERROR BOUND
From the previous Lemma and Theorem, we expect -\nabla \ell_B(x) to be a descent direction reasonably often provided

\frac{1}{|B_t|} \operatorname{Tr} \operatorname{Var}_z \nabla f(x_t, z) \le \theta^2 \, \mathbb{E} \|\nabla \ell_{B_t}(x_t)\|^2, \qquad \text{where } \theta \in (0, 1).

We use this observation in Big Batch SGD.
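In code, this test amounts to comparing a sample-variance estimate against the squared norm of the batch gradient. A minimal sketch, where theta is a free parameter in (0, 1) and, as noted later in the slides, both sides are estimated from the batch itself:

```python
import numpy as np

def noise_dominates(grads, theta=0.5):
    """Return True when the noise estimate exceeds the signal on this batch.

    grads: array of shape (|B|, d) holding one per-example gradient per row.
    Tests (1/|B|) * Tr Var_z grad f(x, z) > theta^2 * ||grad l_B(x)||^2,
    i.e. the failure of the descent condition, signalling that the batch
    should grow.
    """
    g_bar = grads.mean(axis=0)                   # batch gradient estimate
    trace_var = grads.var(axis=0, ddof=1).sum()  # Tr Var_z, estimated from the batch
    return trace_var / len(grads) > theta**2 * (g_bar @ g_bar)
```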
BIG BATCH SGD
pick a batch B → estimate the size of the error using the variance of the batch gradients →
if signal > noise (gradient > variance): update;
otherwise: increase the batch size.
BIG BATCH SGD
On each iteration t:
• Estimate the size of the gradient error by computing the variance
• Pick a batch B_t large enough that

\theta^2 \, \mathbb{E} \|\nabla \ell_{B_t}(x_t)\|^2 \ge \frac{1}{|B_t|} \operatorname{Tr} \operatorname{Var}_z \nabla f(x_t, z), \qquad \text{where } \theta \in (0, 1)

(this is the one extra step relative to SGD; both sides can be estimated using the batch B itself)
• Choose a stepsize \alpha_t
• Update x_{t+1} = x_t - \alpha_t \nabla \ell_{B_t}(x_t)
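Putting the pieces together, a minimal sketch of the fixed-stepsize variant, reusing f_grad and noise_dominates from above. The 10% growth rule and all constants are illustrative assumptions, not the paper's exact choices.

```python
import numpy as np

def big_batch_sgd(data, x0, alpha=0.1, theta=0.5, batch0=32, iters=500, seed=0):
    """Big Batch SGD sketch: grow |B_t| until signal > noise, then step."""
    rng = np.random.default_rng(seed)
    x, batch = x0.copy(), batch0
    for t in range(iters):
        idx = rng.choice(len(data), size=batch, replace=False)
        grads = np.stack([f_grad(x, data[i]) for i in idx])
        while noise_dominates(grads, theta) and batch < len(data):
            # Variance still dominates: grow the batch (~10% here) and resample.
            batch = min(len(data), int(1.1 * batch) + 1)
            idx = rng.choice(len(data), size=batch, replace=False)
            grads = np.stack([f_grad(x, data[i]) for i in idx])
        x = x - alpha * grads.mean(axis=0)  # the usual update, on the big batch
    return x
```

Note that the batch size never shrinks across iterations, matching the slides: it grows adaptively to keep the signal-to-noise ratio of the gradient roughly constant.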
CONVERGENCE
Assumption: \ell has L-Lipschitz gradients:
\ell(x) \le \ell(y) + \nabla \ell(y)^T (x - y) + \frac{L}{2} \|x - y\|^2

Assumption: \ell satisfies the Polyak-Łojasiewicz (PL) inequality:
\|\nabla \ell(x)\|^2 \ge 2\mu (\ell(x) - \ell(x^\star))

CONVERGENCE RESULTS
Theorem. Big Batch SGD converges linearly:
\mathbb{E}[\ell(x_{t+1}) - \ell(x^\star)] \le \Big(1 - \frac{\mu}{\beta L}\Big) \cdot \mathbb{E}[\ell(x_t) - \ell(x^\star)]
with optimal stepsize \alpha = \frac{1}{\beta L}, where \beta = \frac{\theta^2 + (1-\theta)^2}{(1-\theta)^2}.

This is a per-iteration convergence rate.
We also show that reaching error < \epsilon requires only |B| > O(1/\epsilon): optimal convergence in the infinite-data case.
ADVANTAGES
• Controlling the noise enables automated stepsize selection
  – Backtracking line search works well!
  – Stepsize schemes using curvature estimates also work!
• More accurate gradients enable automatic stopping conditions
• Bigger batches are better in parallel/distributed settings
BACKTRACKING LINE SEARCH
(also referred to as Armijo line search)
Measures a sufficient-decrease condition on the objective function.
For regular (deterministic) gradient descent:

\underbrace{\ell(x_{t+1})}_{\text{new objective}} \le \underbrace{\ell(x_t)}_{\text{current objective}} - \underbrace{c \, \alpha_t \|\nabla \ell(x_t)\|^2}_{\text{sufficient decrease}}

If this fails, decrease the stepsize and check again.
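A sketch of the deterministic version of this test; c and the shrink factor are conventional defaults, not values from the slides.

```python
def backtracking_step(loss, grad, x, alpha, c=1e-4, shrink=0.5, max_tries=50):
    """Armijo backtracking: shrink alpha until sufficient decrease holds.

    loss and grad are callables evaluating the objective and its gradient.
    """
    g = grad(x)
    sq_norm = g @ g
    for _ in range(max_tries):  # guard against an endless loop near a stationary point
        if loss(x - alpha * g) <= loss(x) - c * alpha * sq_norm:
            break               # sufficient decrease achieved
        alpha *= shrink         # failed the test: try a smaller stepsize
    return x - alpha * g, alpha
```

Big Batch SGD runs exactly this test, but with the batch objective \ell_B in place of \ell, as the next slides describe.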
BACKTRACKING WITH SGD
Plain SGD measures the sufficient-decrease condition on individual functions,
so it moves toward the optimum of individual functions, not their global average.
Big Batch SGD gets a better estimate of the approximate decrease of the original objective.
BACKTRACKING WORKS WITH BIG BATCH SGD
Measures a sufficient-decrease condition using the batch B. Decrease the stepsize until:

\ell_B(x_{t+1}) \le \ell_B(x_t) - c \, \alpha_t \|\nabla \ell_B(x_t)\|^2

Theorem. Big Batch SGD with backtracking line search converges linearly:
\mathbb{E}[\ell(x_{t+1}) - \ell(x^\star)] \le \Big(1 - \frac{c\mu}{\beta L}\Big) \mathbb{E}[\ell(x_t) - \ell(x^\star)]
with the initial stepsize set large enough that \alpha_0 \ge \frac{1}{2\beta L}.
We also derive optimal stepsizes for Big Batch SGD using Barzilai-Borwein (BB) curvature estimates, with provable guarantees.
Check the paper for details.
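For reference, the classical Barzilai-Borwein estimate that this stepsize scheme builds on. This is the textbook BB1 formula; the paper's variant adapts it to noisy batch gradients, so see the arXiv version for the exact estimator.

```python
import numpy as np

def bb_stepsize(x_prev, x_curr, g_prev, g_curr, eps=1e-12):
    """Barzilai-Borwein (BB1) stepsize: alpha = ||s||^2 / <s, y>."""
    s = x_curr - x_prev   # change in iterates
    y = g_curr - g_prev   # change in (batch) gradients
    # Guard against tiny or negative curvature estimates (illustrative choice).
    return float(s @ s) / max(float(s @ y), eps)
```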
CONVEX EXPERIMENTS
Models: ridge regression & logistic regression.
BBS: Big Batch SGD. Proposed methods are the blue (fixed stepsize), red (backtracking), and green (optimal stepsizes using curvature estimates) curves.

[Figures: BBS with a fixed stepsize; BBS with backtracking (Armijo) line search; BBS with stepsizes from BB curvature estimates; batch-size growth over time; stepsizes used.]
4-LAYER CNN
[Figure: mean class accuracy on the test set vs. number of epochs, for CIFAR-10 (left) and SVHN (right).]
[Figure 2: Neural Network Experiments. The three columns, from left to right, correspond to results for CIFAR-10, SVHN, and MNIST. The top row presents classification accuracies on the training set; the bottom row presents classification accuracies on the test set. Methods compared: Adadelta, BB+Adadelta, SGD+Mom (fine-tuned), SGD+Mom (fixed LR), BBS+Mom (fixed LR), BBS+Mom+Armijo.]
We use a convolutional neural network (LeCun et al., 1998) (ConvNet) to classify three benchmark image datasets: CIFAR-10 (Krizhevsky and Hinton, 2009), SVHN (Netzer et al., 2011), and MNIST (LeCun et al., 1998). Our ConvNet is composed of 4 layers, excluding the input layer, with over 4.3 million weights. To compare against fine-tuned SGD, we used a comprehensive grid search on the stepsize schedule to identify the optimal schedule. Fixed-stepsize methods use the default decay rule of the Torch library, \alpha_t = \alpha_0 / (1 + 10^{-7} t), where \alpha_0 was chosen to be the stepsize used in the fine-tuned experiments. We also tune the hyperparameter \rho in the Adadelta algorithm. Details of the ConvNet and the exact hyperparameters used for training are presented in the supplemental.
We plot the accuracy on the train and test sets vs. the number of epochs (full passes through the dataset) in Figure 2. We notice that big batch SGD with backtracking performs better than both Adadelta and SGD (fixed LR) in terms of both train and test error. Big batch SGD even performs comparably to fine-tuned SGD, but without the trouble of fine tuning. This is interesting because most state-of-the-art deep networks (like AlexNet (Krizhevsky et al., 2012), VGG Net (Simonyan and Zisserman, 2014), and ResNets (He et al., 2016)) were trained by their creators using standard SGD with momentum, with training parameters tuned over long periods of time (sometimes months). Finally, we note that big batch AdaDelta performs consistently better than plain AdaDelta on both large-scale problems (SVHN and CIFAR-10), and performance is nearly identical on the small-scale MNIST problem.
7 Conclusion

We analyzed and studied the behavior of alternative SGD methods in which the batch size increases over time. Unlike classical SGD methods, in which stochastic gradients quickly become swamped with noise, these "big batch" methods maintain a nearly constant signal-to-noise ratio in the approximate gradient. As a result, big batch methods are able to adaptively adjust batch sizes without user oversight. The proposed automated methods are shown to be empirically comparable to, or better performing than, other standard methods, but without requiring an expert user to choose learning rates and decay parameters.
Acknowledgements

This work was supported by the US Office of Naval Research (N00014-17-1-2078) and the National Science Foundation (CCF-1535902 and IIS-1526234). A. Yadav and D. Jacobs were supported by the Office of the Director of National Intelligence (ODNI), Intelligence Advanced Research Projects Activity (IARPA), via IARPA R&D Contract No. 2014-14071600012. The views and conclusions contained herein are those of the authors and should not be interpreted as necessarily representing the official policies or endorsements, either expressed or implied, of the ODNI, IARPA, or the U.S. Government. The U.S. Government is authorized to reproduce and distribute reprints for Governmental purposes notwithstanding any copyright annotation thereon.
Proposed Methods
Big Batch SGD can also be used with AdaGrad, AdaDelta, …
TAKEAWAYS
We introduce: Big Batch SGD
• Adaptively grows the batch size over time to maintain a nearly constant signal-to-noise ratio in the gradients
• Better control of the noise makes it easy to automate
• Adaptive stepsize methods work well with this method
• Better for parallel/distributed settings
THANKS! Feel free to get in touch!
Extended version on arXiv: "Big Batch SGD: Automated Inference using Adaptive Batch Sizes", https://arxiv.org/abs/1610.05792
email: [email protected] | website: https://cs.umd.edu/~sohamde/
Soham De · Tom Goldstein · David Jacobs · Abhay Yadav