Scaling Distributed Machine Learning with System and Algorithm Co-design

Mu Li Thesis Defense

CSD, CMU Feb 2nd, 2017

$\min_w \sum_{i=1}^{n} f_i(w)$

Large-scale problems

✧ Distributed systems
✧ Large scale optimization methods

Large Scale Machine Learning

✧ Machine learning learns from data
✧ More data
  ✓ better accuracy
  ✓ can use more complex models

[Figure: accuracy versus data size, annotated "More complex models"]

Ads Click Prediction

✧ Predict if an ad will be clicked
✧ Each ad impression is an example
✧ Logistic regression
  ✓ Single machine processes 1 million examples per second
✧ A typical industrial size problem has
  ✓ 100 billion examples
  ✓ 10 billion unique features

[Figure: training data size (TB), 0 to 700, by year, 2010 to 2014]

Image Recognition

✧ Recognize the object in an image
✧ Convolutional neural network
✧ A state-of-the-art network
  ✓ Hundreds of layers
  ✓ Billions of floating-point operations for processing a single image

Distributed Computing for Large Scale Problems

✧ Distribute workload among many machines
✧ Widely available thanks to cloud providers (AWS, GCP, Azure)
✧ Challenges
  ✓ Limited communication bandwidth (10x less than memory bandwidth)
  ✓ Large synchronization cost (1ms latency)
  ✓ Job failures

[Diagram: machines connected through switches]
[Figure: failure rate (%), 0 to 24, versus machine time (hour), 100 to 10000]

Distributed Optimization for Large Scale ML

$\min_w \sum_{i=1}^{n} f_i(w)$

✧ The examples are partitioned among m machines: $\bigcup_{i=1}^{m} I_i = \{1, 2, \ldots, n\}$
✧ Machine j computes $\sum_{i \in I_j} \partial f_i(w_t)$ for j = 1, ..., m; the partial sums are added to move from $w_t$ to $w_{t+1}$
✧ Challenges
  ✓ Massive communication traffic
  ✓ Expensive global synchronization
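The partitioned gradient computation on this slide can be sketched in a few lines of Python. This is a minimal single-process illustration, assuming a squared loss $f_i(w) = \frac{1}{2}(\langle x_i, w \rangle - y_i)^2$ and a fixed step size; the function names, the loss, and the toy data are illustrative assumptions, not taken from the talk.

```python
def partial_gradient(w, examples):
    """Sum of gradients of f_i(w) = 0.5 * (<x_i, w> - y_i)^2 over one partition."""
    grad = [0.0] * len(w)
    for x, y in examples:
        residual = sum(wj * xj for wj, xj in zip(w, x)) - y
        for j, xj in enumerate(x):
            grad[j] += residual * xj
    return grad


def distributed_step(w, partitions, lr):
    """One step: each 'machine' sums gradients over its partition, then the sums are added."""
    total = [0.0] * len(w)
    for part in partitions:  # in a real system each partition lives on its own machine
        g = partial_gradient(w, part)
        total = [t + gj for t, gj in zip(total, g)]
    return [wj - lr * tj for wj, tj in zip(w, total)]


# Toy data with y = 2 * x, split into m = 2 partitions I_1 and I_2.
data = [([1.0], 2.0), ([2.0], 4.0), ([3.0], 6.0), ([4.0], 8.0)]
partitions = [data[:2], data[2:]]

w = [0.0]
for _ in range(50):
    w = distributed_step(w, partitions, lr=0.02)
```

Every step moves one gradient vector per machine and waits for all of them, which is where the two challenges listed on the slide come from.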

Scaling Distributed Machine Learning

Distributed Systems
  ✓ Large data size, complex models
  ✓ Fault tolerant
  ✓ Easy to use

Large Scale Optimization
  ✓ Communication efficient
  ✓ Convergence guarantee

Scaling Distributed Machine Learning

Distributed Systems
  Parameter Server for machine learning
  MXNet for deep learning

Large Scale Optimization
  DBPG for non-convex non-smooth $f_i$
  EMSO for efficient minibatch SGD

Existing Open Source Systems in 2012

✧ MPI (message passing interface)
  ✓ Hard to use for sparse problems
  ✓ No fault tolerance
✧ Key-value store, e.g. redis
  ✓ Expensive individual key-value pair communication
  ✓ Difficult to program on the server side
✧ Hadoop/Spark
  ✓ BSP data consistency makes efficient implementation challenging

Parameter Server Architecture [Smola’10, Dean’12]

[Diagram: training data feeds the worker machines; workers push to the server machines, the servers update the model, and workers pull from the servers]
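The push, update, and pull cycle in the diagram can be illustrated with a toy, in-memory sketch. The class and function names, the squared loss, and the plain gradient update on the server are assumptions made for illustration; they are not the interface of the system presented in the talk.

```python
class ToyParameterServer:
    """Holds the model as a key-value map; keys are feature IDs."""

    def __init__(self, lr):
        self.lr = lr
        self.weights = {}

    def push(self, grads):
        """Receive {key: gradient} from a worker and update the stored weights."""
        for key, g in grads.items():
            self.weights[key] = self.weights.get(key, 0.0) - self.lr * g

    def pull(self, keys):
        """Return the current values for the requested keys only."""
        return {key: self.weights.get(key, 0.0) for key in keys}


def worker_iteration(server, example):
    """One worker step on a sparse example: pull, compute gradient, push."""
    features, label = example                 # features: {key: value}
    w = server.pull(features.keys())          # fetch only the keys this example touches
    pred = sum(w[k] * v for k, v in features.items())
    residual = pred - label                   # squared-loss gradient (assumed loss)
    server.push({k: residual * v for k, v in features.items()})


server = ToyParameterServer(lr=0.1)
shard = [({"f1": 1.0}, 1.0), ({"f2": 1.0, "f3": 2.0}, 0.0), ({"f1": 1.0}, 1.0)]
for example in shard:
    worker_iteration(server, example)
```

Because each example is sparse, a worker only exchanges the handful of keys it actually uses rather than the whole model.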

Key Features of our Implementation [Li et al, OSDI’14]

✧ Trade off data consistency for speed
  ✓ Flexible consistency models
  ✓ User-defined filters
✧ Fault tolerance with chain replication

Flexible Consistency Model

✧ Flexible models via task dependency graph
  ✓ With a dependency, iter 1 (gradient, push & pull) executes after iter 0 has finished
  ✓ With no dependency, iter 1 does not wait for iter 0
✧ Sequential / BSP
✧ Eventual / Total asynchronous [Smola 10]
✧ Bounded delay / SSP [Langford 09, Cipar 13]

[Diagram: task dependency graphs over tasks 1 to 5 for each of the three models]
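One way to read the three consistency models is as a single rule with a different maximal delay. The sketch below is an assumption-laden illustration (0-based task indices, a hypothetical `can_start` helper), not the scheduler from the talk.

```python
def can_start(task, finished, max_delay):
    """Return True if `task` may start under a bounded-delay consistency model.

    Task t requires every task up to t - max_delay - 1 to have finished:
      max_delay = 0     -> sequential: task t waits for task t - 1
      max_delay = None  -> eventual: no dependencies at all
    """
    if max_delay is None:
        return True
    required = range(0, task - max_delay)  # tasks 0 .. task - max_delay - 1
    return all(t in finished for t in required)
```

With `max_delay=1`, task 3 can start while task 2 is still running, but not before tasks 0 and 1 are done.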

User-defined Filters

✧ User defined encoder/decoder for efficient communication
✧ Lossless compression
  ✓ General data compression: LZ, LZR, ...
✧ Lossy compression
  ✓ Random skip
  ✓ Fixed-point encoding

[Diagram: data on one machine passes through an encoder, travels to another machine, and is restored by a decoder]
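As an illustration of a lossy encoder/decoder pair, here is a minimal fixed-point filter that sends each value as one signed byte. The function names, the 8-bit width, and the `scale` parameter are assumptions for this sketch; the talk does not specify the encoding details.

```python
import struct


def encode_fixed_point(values, scale):
    """Lossy encoder: map each float in [-scale, scale] to a signed 8-bit integer."""
    quantized = []
    for v in values:
        clipped = max(-scale, min(scale, v))
        quantized.append(int(round(clipped / scale * 127)))
    return struct.pack(f"{len(quantized)}b", *quantized)


def decode_fixed_point(payload, scale):
    """Decoder: recover approximate floats from the 8-bit payload."""
    quantized = struct.unpack(f"{len(payload)}b", payload)
    return [q / 127 * scale for q in quantized]


gradients = [0.5, -0.25, 0.0, 1.0]
payload = encode_fixed_point(gradients, scale=1.0)
restored = decode_fixed_point(payload, scale=1.0)
```

The payload uses one byte per value instead of four for a 32-bit float, at the price of a rounding error of at most `scale / 254` per value.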

Fault Tolerance with Chain Replication

✧ Model is partitioned by consistent hashing
✧ Chain replication
✧ Option: aggregation reduces backup traffic

[Diagram: worker 0 pushes to server 0, which pushes to server 1; acks travel back from server 1 to server 0 to worker 0. With aggregation, server 0 receives pushes from worker 0 and worker 1 and forwards a single push to server 1]
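The two ideas on this slide, partitioning keys by consistent hashing and replicating each push along a chain before acknowledging it, can be sketched as follows. This is a simplified illustration: the `Ring` class, the MD5-based hash, the absence of virtual nodes, and the in-memory `stores` dictionary are all assumptions, not the talk's implementation.

```python
import bisect
import hashlib


def _hash(name):
    """Map a string to a point on the ring (a large integer)."""
    return int(hashlib.md5(name.encode("utf-8")).hexdigest(), 16)


class Ring:
    """Consistent-hash ring: each key is owned by the next server clockwise."""

    def __init__(self, servers, replicas=1):
        self.replicas = replicas
        self.points = sorted((_hash(s), s) for s in servers)

    def chain(self, key):
        """Return the owner of `key` followed by its backup servers."""
        hashes = [h for h, _ in self.points]
        start = bisect.bisect_right(hashes, _hash(key)) % len(self.points)
        count = min(self.replicas + 1, len(self.points))
        return [self.points[(start + i) % len(self.points)][1] for i in range(count)]


def replicated_push(stores, ring, key, value):
    """Apply a push along the chain; the ack is returned only after the last replica stores it."""
    for server in ring.chain(key):
        stores[server][key] = value
    return "ack"


servers = ["server0", "server1", "server2"]
ring = Ring(servers, replicas=1)
stores = {s: {} for s in servers}
result = replicated_push(stores, ring, "w:42", 0.5)
```

If the owner of a key fails, its successor on the ring already holds a copy of that key's value.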


Proximal Gradient Method [Combettes’09]

$\min_{w \in \Omega} \sum_{i=1}^{n} f_i(w) + h(w)$

  ✓ $f_i$: continuously differentiable but not necessarily convex
  ✓ $h$: convex but possibly non-smooth

✧ Iterative update

$w_{t+1} = \mathrm{Prox}_{\eta_t}\left[ w_t - \eta_t \sum_{i=1}^{n} \nabla f_i(w_t) \right]$

where

$\mathrm{Prox}_{\eta}(x) := \operatorname{argmin}_{y \in \Omega} \; h(y) + \frac{1}{2\eta} \|x - y\|^2$
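For the common choice $h(w) = \lambda \|w\|_1$ with $\Omega = \mathbb{R}^d$, the proximal operator has a closed form (soft thresholding), which makes the update easy to sketch. The choice of $h$, the function names, and the sample numbers below are assumptions for illustration only.

```python
def soft_threshold(x, threshold):
    """Prox of h(y) = lam * |y| with step eta, where threshold = eta * lam."""
    if x > threshold:
        return x - threshold
    if x < -threshold:
        return x + threshold
    return 0.0


def proximal_gradient_step(w, grad, eta, lam):
    """w_{t+1} = Prox_eta[w_t - eta * grad], applied coordinate-wise for the l1 penalty."""
    return [soft_threshold(wj - eta * gj, eta * lam) for wj, gj in zip(w, grad)]


w_next = proximal_gradient_step([1.0, 0.05, -1.0], [0.0, 0.0, 0.0], eta=0.1, lam=1.0)
```

Coordinates whose magnitude after the gradient step is below `eta * lam` are set exactly to zero, which is how the non-smooth penalty produces sparse models.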

Delayed Block Proximal Gradient [Li et al, NIPS’14]

✧ Algorithm design tailored for parameter server implementation
  ✓ Update a block of coordinates each time
  ✓ Allow delay among blocks
  ✓ Use filters during communication
✧ Only 300 lines of code

[Diagram labels: data, model]

Convergence Analysis

✧ Assumptions:
  ✓ Block Lipschitz continuity: $L_{var}$ within a block, $L_{cor}$ across blocks
  ✓ Delay is bounded by $\tau$
  ✓ Lossy compressions such as random skip filter and significantly-modified filter
✧ DBPG converges to a stationary point if the learning rate is chosen as

$\eta_t < \frac{1}{L_{var} + \tau L_{cor}}$
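To see how the admissible learning rate shrinks as the allowed delay grows, the bound can be evaluated directly. The values of $L_{var}$ and $L_{cor}$ below are hypothetical and chosen only to show the trend.

```python
def max_learning_rate(l_var, l_cor, tau):
    """Upper bound 1 / (L_var + tau * L_cor) on the DBPG learning rate."""
    return 1.0 / (l_var + tau * l_cor)


# Hypothetical constants, chosen only to show the trend.
bounds = {tau: max_learning_rate(l_var=4.0, l_cor=0.5, tau=tau)
          for tau in (0, 1, 2, 4, 8, 16)}
```

When $L_{cor}$ is small relative to $L_{var}$, a larger delay costs little in step size, so the reduced waiting time can outweigh the more conservative learning rate.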

Experiments on Ads Click Prediction

✧ Real dataset used in production
  ✓ 170 billion examples, 65 billion unique features, 636 TB in total
✧ 1000 machines
✧ Sparse logistic regression

$\min_w \sum_{i=1}^{n} \log(1 + \exp(-y_i \langle x_i, w \rangle)) + \lambda \|w\|_1$

[Figure: time to achieve the same objective value; time (hour), 0 to 2, split into computing and waiting, versus maximal delay $\tau$ = 0, 1, 2, 4, 8, 16; annotated "sequential" and "best trade-off, 1.6x"]
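The sparse logistic regression objective can be evaluated as follows. This is a small sketch that assumes labels $y_i \in \{-1, +1\}$ and stores both examples and weights as dictionaries keyed by feature index; the function name and data layout are illustrative.

```python
import math


def sparse_logistic_objective(w, examples, lam):
    """sum_i log(1 + exp(-y_i <x_i, w>)) + lam * ||w||_1 with sparse x_i and y_i in {-1, +1}."""
    loss = 0.0
    for features, y in examples:               # features: {index: value}
        margin = y * sum(w.get(j, 0.0) * v for j, v in features.items())
        # log(1 + exp(-margin)), written to avoid overflow for large |margin|
        loss += max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))
    return loss + lam * sum(abs(v) for v in w.values())
```

Only the features present in an example contribute to its inner product, so the cost per example depends on its number of non-zero features rather than on the total number of unique features.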

Filters to Reduce Communication Traffic

✧ Key caching
  ✓ Cache feature IDs on both sender and receiver
✧ Data compression
✧ KKT filter
  ✓ Shrink gradient to 0 based on the KKT condition

[Figure: traffic (%), 0 to 100, for Baseline, Key Caching, Compressing, and KKT Filter; two panels. Server: reductions of 2x, 40x, 40x. Worker: reductions of 2x, 2.5x, 12x]
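A KKT filter for an $\ell_1$-regularized problem can be sketched as below. For $\min_w f(w) + \lambda \|w\|_1$, a coordinate with $w_j = 0$ and $|g_j| \le \lambda$ already satisfies the optimality condition, so a proximal step leaves it at zero and its gradient need not be sent. The function name and the safety margin `delta` are assumptions for this illustration.

```python
def kkt_filter(weights, grads, lam, delta=0.0):
    """Keep only gradient entries that can move their weight.

    For min f(w) + lam * ||w||_1, a coordinate with w_j = 0 and
    |g_j| <= lam - delta satisfies the KKT condition and stays at zero,
    so its gradient is dropped (treated as 0) instead of being sent.
    """
    kept = {}
    for key, g in grads.items():
        at_zero = weights.get(key, 0.0) == 0.0
        if at_zero and abs(g) <= lam - delta:
            continue
        kept[key] = g
    return kept


weights = {"a": 0.0, "b": 0.0, "c": 0.3}
grads = {"a": 0.05, "b": 0.5, "c": 0.01}
to_send = kkt_filter(weights, grads, lam=0.1)
```

Sparse models have most weights at zero, so a filter of this kind can skip a large share of the gradient entries.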


Deep Learning is Unique

✧ Complex workloads
✧ Heterogeneous computing
✧ Easy to use programming interface

[Figure: “deep learning” trend in the past 5 years]

Key Features of MXNet

✧ Easy-to-use front-end
  ✓ Mixed programming
✧ Scalable and efficient back-end
  ✓ Computation and memory optimization
  ✓ Auto-parallelization
  ✓ Scaling to multiple GPU/machines

[Chen et al, NIPS’15 workshop] (corresponding author)

Mixed Programming

✧ Imperative programming is flexible
  ✓ e.g. Numpy, Matlab, Torch, …
  ✓ Good for updating and interacting with the neural network

    import mxnet as mx
    a = mx.nd.zeros((100, 50))
    b = mx.nd.ones((100, 50))
    c = a * b
    print(c)
    c += 1

✧ Declarative programs are easy to optimize
  ✓ e.g. TensorFlow, Theano, Caffe, …
  ✓ Good for defining the neural network

    import mxnet as mx
    net = mx.symbol.Variable('data')
    net = mx.symbol.FullyConnected(data=net, num_hidden=128)
    net = mx.symbol.SoftmaxOutput(data=net)
    model = mx.module.Module(net)
    # Slide shorthand: a real Module needs bind() and init_params() first,
    # and forward() takes a DataBatch rather than a bare array.
    model.forward(data=c)
    model.backward()

Back-end System

✧ Front-end: the same imperative and declarative programs as on the Mixed Programming slide, with the module named texec

[Diagram: the back-end graph built from both programs; a and b are combined, 1 is added to give c, and c feeds fullc (with weight and bias) followed by softmax]

✧ Optimization
  ✓ Memory optimization
  ✓ Operator fusion and runtime compilation
✧ Scheduling
  ✓ Auto-parallelization

Scale to Multiple GPU Machines

✧ Hierarchical parameter server
  ✓ Level-1 servers
  ✓ Level-2 servers
✧ 1000 lines of code

[Diagram: GPUs connect through a PCIe switch to the CPU, and the CPU connects to a network switch; labels: Workers, Level-1 Servers, Level-2 Servers, GPUs, CPUs]

Link bandwidths:
  63 GB/s: 4 PCIe 3.0 16x
  15.75 GB/s: PCIe 3.0 16x
  1.25 GB/s: 10 Gbit Ethernet
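The bandwidth gap between the links can be made concrete with a back-of-the-envelope calculation. The bandwidths come from the slide; the parameter count (60 million float32 values, roughly the size of a 152-layer ResNet) is an assumption, and the calculation ignores latency and protocol overhead.

```python
# Link bandwidths from the slide, in GB/s.
LINKS = {
    "4x PCIe 3.0 16x": 63.0,
    "PCIe 3.0 16x": 15.75,
    "10 Gbit Ethernet": 1.25,
}


def transfer_seconds(num_params, bytes_per_param, bandwidth_gb_per_s):
    """Time to move one copy of the parameters over a link (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / (bandwidth_gb_per_s * 1e9)


# Assumption: 60 million float32 parameters.
times = {name: transfer_seconds(60e6, 4, bw) for name, bw in LINKS.items()}
```

Under these assumptions one copy of the gradients takes about 0.19 s over Ethernet and about 0.015 s over a single PCIe link, which suggests why combining gradients inside a machine before crossing the network is attractive.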

Experiment Setup

✧ Dataset
  ✓ 1.2 million images with 1000 classes
✧ ResNet 152-layer model
✧ EC2 p2.8xlarge
  ✓ 8 K80 GPUs per machine
✧ Minibatch SGD
  ✓ Draw a random set of examples $I_t$ at iteration t
  ✓ Update $w_{t+1} = w_t - \frac{\eta_t}{|I_t|} \sum_{i \in I_t} \partial f_i(w_t)$
✧ Synchronized updating

[Diagram: GPU 0-7 connected through PCIe switches to the CPU]
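The minibatch SGD update can be written out in plain Python. This is a single-process sketch with an assumed squared loss and noise-free toy data; the function names and hyperparameters are illustrative and unrelated to the ResNet experiment.

```python
import random


def minibatch_sgd(examples, grad_fn, w, lr, batch_size, steps, seed=0):
    """Run w <- w - (lr / |I_t|) * sum_{i in I_t} grad f_i(w) for `steps` iterations."""
    rng = random.Random(seed)
    for _ in range(steps):
        batch = rng.sample(examples, batch_size)   # random set I_t, without replacement
        total = [0.0] * len(w)
        for example in batch:
            g = grad_fn(w, example)
            total = [t + gj for t, gj in zip(total, g)]
        w = [wj - lr / len(batch) * tj for wj, tj in zip(w, total)]
    return w


def squared_loss_grad(w, example):
    """Gradient of 0.5 * (<x, w> - y)^2 (an assumed loss for the demo)."""
    x, y = example
    residual = sum(wj * xj for wj, xj in zip(w, x)) - y
    return [residual * xj for xj in x]


data = [([float(i)], 3.0 * i) for i in range(1, 6)]   # y = 3 * x
w_final = minibatch_sgd(data, squared_loss_grad, [0.0], lr=0.05, batch_size=2, steps=200)
```

With synchronized updating, the inner sum is what gets split across GPUs: each device handles part of $I_t$ and the partial sums are combined before the weights change.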

Communication Cost

✧ Fix #GPUs per machine

[Figure: time (sec), 0.2 to 0.7, versus # of GPUs, 0 to 128, for 1, 2, 4, and 8 GPU/machine]

Scalability

[Figure: time (sec), 0 to 1, versus # of GPUs, 0 to 128; curves for communication cost and for batch size/GPU = 4, 8, 16; annotated "115x speedup"]

Convergence

✧ Increase learning rate by 5x
✧ Increase learning rate by 10x, decrease it at epoch 50, 80

[Figure: top-1 validation accuracy, 0.1 to 0.8, versus # of epochs, 0 to 120, for batch size = 256, 2,560, and 5,120]


Minibatch SGD

✧ Large batch size b in SGD
  ✓ Better parallelization within a batch
  ✓ Less switching/communication cost
✧ Small batch size b
  ✓ Faster convergence: $O(1/\sqrt{N} + b/N)$, where N is the number of examples processed

[Figure: system performance and convergence rate versus batch size (b), on a vertical axis running from "Worse" to "Better"]
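The effect of the batch size on the rate $O(1/\sqrt{N} + b/N)$ is easy to tabulate. The sketch drops the constants hidden by the $O(\cdot)$ notation, and the value of N is hypothetical.

```python
import math


def sgd_rate_bound(num_examples, batch_size):
    """Evaluate 1/sqrt(N) + b/N, ignoring the constants hidden by the O(.) notation."""
    return 1.0 / math.sqrt(num_examples) + batch_size / num_examples


N = 10**8  # hypothetical number of processed examples
table = {b: sgd_rate_bound(N, b) for b in (10**3, 10**4, 10**5)}
```

The $b/N$ term is negligible while $b$ is well below $\sqrt{N}$ (here $10^4$) and dominates beyond it, which is the regime where large batches hurt convergence.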

Motivation [Li et al, KDD’14]

✧ Improve convergence rate for large batch size
  ✓ Example variance decreases with batch size
  ✓ Solve a more “accurate” optimization subproblem over each batch

[Figure: system performance and convergence rate versus batch size (b)]

Efficient Minibatch SGD

✧ Define $f_{I_t}(w) := \sum_{i \in I_t} f_i(w)$. Minibatch SGD solves

$w_t = \operatorname{argmin}_{w \in \Omega} \left[ f_{I_t}(w_{t-1}) + \langle \partial f_{I_t}(w_{t-1}), w - w_{t-1} \rangle + \frac{1}{2\eta_t} \|w - w_{t-1}\|_2^2 \right]$

  ✓ First-order approximation: $f_{I_t}(w_{t-1}) + \langle \partial f_{I_t}(w_{t-1}), w - w_{t-1} \rangle$
  ✓ Conservative penalty: $\frac{1}{2\eta_t} \|w - w_{t-1}\|_2^2$

✧ EMSO solves the subproblem at each iteration

$w_t = \operatorname{argmin}_{w \in \Omega} \left[ f_{I_t}(w) + \frac{1}{2\eta_t} \|w - w_{t-1}\|_2^2 \right]$

  ✓ Exact objective: $f_{I_t}(w)$

✧ For convex $f_i$, choose $\eta_t = O(b/\sqrt{N})$. EMSO has convergence rate $O(1/\sqrt{N})$ (compared to $O(1/\sqrt{N} + b/N)$)
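The difference between the two subproblems can be seen on a toy batch. In the sketch below the SGD step minimizes the linearized objective in closed form, while the EMSO step approximately minimizes the exact batch objective plus the penalty using a few inner gradient-descent iterations. The inner solver, the squared loss, and all names and constants are assumptions; the talk does not say how the subproblem is solved.

```python
def sgd_step(w_prev, batch, grad_fn, eta):
    """Minimizer of the first-order approximation plus the penalty: a plain gradient step."""
    g = grad_fn(w_prev, batch)
    return [wj - eta * gj for wj, gj in zip(w_prev, g)]


def emso_step(w_prev, batch, grad_fn, eta, inner_steps=100, inner_lr=0.01):
    """Approximately minimize f_I(w) + ||w - w_prev||^2 / (2 * eta) by gradient descent."""
    w = list(w_prev)
    for _ in range(inner_steps):
        g = grad_fn(w, batch)
        w = [wj - inner_lr * (gj + (wj - pj) / eta)
             for wj, gj, pj in zip(w, g, w_prev)]
    return w


def batch_grad(w, batch):
    """Gradient of f_I(w) = sum_i 0.5 * (<x_i, w> - y_i)^2 (an assumed loss)."""
    grad = [0.0] * len(w)
    for x, y in batch:
        residual = sum(wj * xj for wj, xj in zip(w, x)) - y
        grad = [gj + residual * xj for gj, xj in zip(grad, x)]
    return grad


batch = [([1.0], 2.0), ([2.0], 4.0), ([3.0], 6.0)]   # y = 2 * x
w_sgd = sgd_step([0.0], batch, batch_grad, eta=1.0)
w_emso = emso_step([0.0], batch, batch_grad, eta=1.0)
```

With the same large step size, the linearized step overshoots the batch optimum at 2, while the exact-objective step lands close to it because the true curvature of $f_{I_t}$ is taken into account.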

Experiment

✧ Ads click prediction with fixed run time
✧ Extended to deep learning in [Keskar et al, arXiv’16]

[Figure: objective versus batch size (1e3, 1e4, 1e5) for SGD and EMSO; two panels: single machine and 10 machines]

$\min_w \sum_{i=1}^{n} f_i(w)$

Large-scale problems

✧ Distributed systems
✧ Large scale optimization

Co-design: reduce communication cost
  ✓ Communicate less
  ✓ Message compression
  ✓ Relaxed data consistency

With appropriate computational frameworks and algorithm design, distributed machine learning can be made simple, fast, and scalable, both in theory and in practice.

Acknowledgement

with 13 other collaborators

Advisors

QQ & Alex

Committee members

Backup slides


Scaling to 16 GPUs in a Single Machine

[Figure: time (sec), 0 to 1, versus # of GPUs, 0 to 16; curves for communication cost and for batch size/GPU = 2, 4, 8, 16; annotated "Communication dominates" and "15x"]

Compare to an L-BFGS Based System

[Figure: objective versus time (hour), 0.1 to 10, for System A and Parameter Server]
[Figure: time (hour), 0 to 5, split into computing and waiting, for System A and Parameter Server]

Sections not Covered

✧ AdaDelay: model the actual delay for asynchronous SGD
✧ Parsa: data placement to reduce communication cost
✧ Difacto: large scale factorization machine