Learning Accurate Low-bit Deep Neural Networks with Stochastic Quantization Yinpeng Dong 1 , Renkun Ni 2 , Jianguo Li 3 , Yurong Chen 3 , Jun Zhu 1 , Hang Su 1 1 Department of CST, Tsinghua University 2 University of Virginia 3 Intel Labs China

LearningAccurate Low-bit Deep Neural Networks with ...ml.cs.tsinghua.edu.cn/~yinpeng/slides/BMVC17.pdf · Deep Learning is Everywhere Self-Driving Alpha Go MachineTranslation Dota


Page 1

Learning Accurate Low-bit Deep Neural Networks with Stochastic Quantization

Yinpeng Dong1, Renkun Ni2, Jianguo Li3, Yurong Chen3, Jun Zhu1, Hang Su1

1 Department of CST, Tsinghua University
2 University of Virginia
3 Intel Labs China

Page 2

Deep Learning is Everywhere

Self-Driving, AlphaGo, Dota, Machine Translation

Page 3

Limitations

• More data + deeper models → more FLOPs + larger memory
• Computation intensive
• Memory intensive
• Hard to deploy on mobile devices

Page 4

Low-bit DNNs for Efficient Inference

• High redundancy in DNNs;
• Quantize full-precision (32-bit) weights to binary (1-bit) or ternary (2-bit) weights;
• Replace multiplication (convolution) by addition and subtraction.

Page 5

Typical Low-bit DNNs

• BinaryConnect:

$$B_i = \begin{cases} +1 & \text{with probability } p = \sigma(W_i) \\ -1 & \text{with probability } 1 - p \end{cases}$$

• BWN: minimize $\|W - \alpha B\|$

$$B_i = \mathrm{sign}(W_i), \qquad \alpha = \frac{1}{d} \sum_{i=1}^{d} |W_i|$$

• TWN: minimize $\|W - \alpha T\|$

$$T_i = \begin{cases} +1 & \text{if } W_i > \Delta \\ 0 & \text{if } |W_i| \le \Delta \\ -1 & \text{if } W_i < -\Delta \end{cases}, \qquad \alpha = \frac{1}{|I_\Delta|} \sum_{i \in I_\Delta} |W_i|$$

$$I_\Delta = \{\, i : |W_i| > \Delta \,\}, \qquad \Delta = \frac{0.7}{d} \sum_{i=1}^{d} |W_i|$$

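The BWN and TWN rules on this slide can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: the function names `bwn_quantize` and `twn_quantize` are hypothetical, exact zeros are mapped to +1 in the binary case, and the example weight vector is made up.

```python
import numpy as np

def bwn_quantize(w):
    """Binary weights: B = sign(W), alpha = mean(|W|)."""
    b = np.where(w >= 0, 1.0, -1.0)  # assumption: exact zeros map to +1
    alpha = np.abs(w).mean()
    return alpha, b

def twn_quantize(w):
    """Ternary weights with threshold delta = 0.7 * mean(|W|)."""
    delta = 0.7 * np.abs(w).mean()
    t = np.zeros_like(w)
    t[w > delta] = 1.0
    t[w < -delta] = -1.0
    mask = np.abs(w) > delta  # index set I_delta
    # alpha averages |W_i| over the entries that were not zeroed out
    alpha = np.abs(w[mask]).mean() if mask.any() else 0.0
    return alpha, t

# Made-up example weights (not taken from the slides)
w = np.array([1.3, -1.1, 0.75, 0.85, 0.1, -0.2])
alpha_b, b = bwn_quantize(w)
alpha_t, t = twn_quantize(w)
print("BWN:", alpha_b, b)
print("TWN:", alpha_t, t)
```

In this example the two small weights (0.1 and -0.2) fall below the threshold and become 0 under TWN, while BWN forces them to ±1.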
Page 6

Training & Inference of Low-bit DNN

• Let $W$ be the full-precision weights and $Q$ be the low-bit weights ($B$, $T$, $\alpha B$, $\alpha T$).
• Forward propagation: quantize $W$ to $Q$ and perform convolution or multiplication.
• Backward propagation: use $Q$ to calculate gradients.
• Parameter update: $W^{t+1} = W^t - \eta^t \frac{\partial L}{\partial Q^t}$
• Inference: only need to keep low-bit weights $Q$.
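The training procedure on this slide can be illustrated with a toy example. The sketch below assumes a linear model with a mean-squared-error loss and a BWN-style quantizer; the names `quantize` and `train`, the data, the learning rate and the step count are all hypothetical choices for illustration. The key point it shows is that the gradient is computed with respect to the quantized weights Q but applied to the full-precision weights W.

```python
import numpy as np

def quantize(w):
    """BWN-style quantizer: Q = alpha * sign(W), with alpha = mean(|W|)."""
    alpha = np.abs(w).mean()
    return alpha * np.where(w >= 0, 1.0, -1.0)

def train(x, y, w, lr=0.5, steps=50):
    """Forward and backward passes use Q; the update is applied to W."""
    n = len(y)
    for _ in range(steps):
        q = quantize(w)                              # forward: W -> Q
        grad_q = (2.0 / n) * (x.T @ (x @ q - y))     # dL/dQ for L = mean((xQ - y)^2)
        w = w - lr * grad_q                          # W^{t+1} = W^t - eta * dL/dQ^t
    return w

# Toy problem whose exact solution is itself a scaled binary vector
x = np.eye(4)
y = np.array([0.5, -0.5, 0.5, -0.5])
w = train(x, y, np.full(4, 0.1))
q = quantize(w)  # only Q needs to be kept for inference
print(q)
```

Because the target here is exactly representable as a scaled sign vector, the quantized weights approach `y`; on real networks the quantization error generally does not vanish.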

Page 7

Motivations

• Quantize all weights simultaneously;
• Quantization error $W - Q$ may be large for some elements/filters;
• Induce inappropriate gradient directions.

• Quantize a portion of weights
• Stochastic selection
• Could be applied to any low-bit settings

Page 8

Roulette Selection Algorithm

[Figure: a 4×4 weight matrix with one channel per row, its per-channel quantization error, a roulette-wheel stochastic partition with r = 50%, and the resulting hybrid weight matrix.]

Channel | Weight Matrix             | Quantization Error | Hybrid Weight Matrix
C1      |  1.3   -1.1   0.75   0.85 | 0.2                |  1.3  -1.1   0.75  0.85
C2      |  0.95  -0.9   1.05  -1.0  | 0.05               |  1    -1     1    -1
C3      |  1.4   -0.9  -0.8    0.9  | 0.2                |  1    -1    -1     1
C4      | -1.2    0.8   1.0   -1.0  | 0.1                | -1.2   0.8   1.0  -1.0

• 1st selection: v = 0.58, C2 selected
• 2nd selection: v = 0.37, C3 selected

• Quantization Error: $e_i = \frac{\|W_i - Q_i\|_1}{\|W_i\|_1}$
• Quantization Probability: larger quantization error means smaller quantization probability, e.g. $p_i \propto \frac{1}{e_i}$
• Quantization Ratio $r$: gradually increase to 100%
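The roulette selection on this slide can be sketched as sampling channels without replacement, with probability inversely proportional to the quantization error. The code below is a minimal sketch, not the released implementation: the function names, the small `eps` guard against division by zero, and the `draws` argument (used here to replay the two values of v from the slide) are assumptions. The weight matrix is my reading of the slide's figure, with Q = sign(W).

```python
import numpy as np

def quantization_error(w, q):
    """Per-channel relative L1 error e_i = ||W_i - Q_i||_1 / ||W_i||_1."""
    return np.abs(w - q).sum(axis=1) / np.abs(w).sum(axis=1)

def roulette_select(errors, ratio, draws=None, rng=None, eps=1e-7):
    """Select round(ratio * m) channels without replacement, p_i proportional to 1/e_i.

    `draws` optionally supplies the uniform samples v in [0, 1) so that a
    run can be replayed exactly; otherwise they are drawn from `rng`.
    """
    if rng is None:
        rng = np.random.default_rng()
    f = 1.0 / (np.asarray(errors, dtype=float) + eps)
    remaining = list(range(len(f)))
    selected = []
    for k in range(int(round(ratio * len(f)))):
        p = f[remaining] / f[remaining].sum()   # renormalise over unselected channels
        v = draws[k] if draws is not None else rng.uniform()
        # first channel whose cumulative probability reaches v
        idx = min(int(np.searchsorted(np.cumsum(p), v)), len(remaining) - 1)
        selected.append(remaining.pop(idx))
    return selected

# Weight matrix as read from the slide's figure (rows are channels C1..C4)
w = np.array([[1.3, -1.1, 0.75, 0.85],
              [0.95, -0.9, 1.05, -1.0],
              [1.4, -0.9, -0.8, 0.9],
              [-1.2, 0.8, 1.0, -1.0]])
q = np.sign(w)
errors = quantization_error(w, q)
selected = roulette_select(errors, ratio=0.5, draws=[0.58, 0.37])
print(errors.round(2), selected)
```

With the slide's draws v = 0.58 and v = 0.37, this selects indices 1 and 2, that is, channels C2 and C3, matching the figure.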

Page 9

Training & Inference

• Hybrid weight matrix $\tilde{Q}$:

$$\tilde{Q}_i = \begin{cases} Q_i & \text{if channel } i \text{ is selected} \\ W_i & \text{otherwise} \end{cases}$$

• Parameter update:

$$W^{t+1} = W^t - \eta^t \frac{\partial L}{\partial \tilde{Q}^t}$$

• Inference: all weights are quantized; use $Q$ to perform inference.
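Building the hybrid weight matrix amounts to overwriting the selected rows of W with the corresponding rows of Q. The sketch below assumes channel-wise (row-wise) selection and Q = sign(W); `hybrid_weights` and `sq_update` are hypothetical helper names, and the matrix and the selected channels (C2 and C3) follow the roulette example from the previous slide.

```python
import numpy as np

def hybrid_weights(w, q, selected):
    """Return Q~: rows listed in `selected` come from Q, all others from W."""
    q_tilde = w.copy()
    q_tilde[selected] = q[selected]
    return q_tilde

def sq_update(w, grad_q_tilde, lr):
    """W^{t+1} = W^t - eta * dL/dQ~^t (gradient taken w.r.t. the hybrid weights)."""
    return w - lr * grad_q_tilde

w = np.array([[1.3, -1.1, 0.75, 0.85],
              [0.95, -0.9, 1.05, -1.0],
              [1.4, -0.9, -0.8, 0.9],
              [-1.2, 0.8, 1.0, -1.0]])
q = np.sign(w)
q_tilde = hybrid_weights(w, q, selected=[1, 2])  # channels C2 and C3
print(q_tilde)
```

At inference time the hybrid matrix is no longer used: every channel is quantized and only `q` is kept.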

Page 10

Ablation Studies

• Selection granularity
    ◦ Filter-level > Element-level
• Selection/partition algorithms
    ◦ Stochastic (roulette) > deterministic (sorting) ~ fixed (selection only at first iteration)
• Quantization functions
    ◦ Linear > Sigmoid > Constant ~ Softmax
    ◦ Softmax: $p_i = \exp(f_i) / \sum_j \exp(f_j)$, where $f = \frac{1}{e}$
• Quantization ratio update scheme
    ◦ Exponential > Fine-tune > Uniform
    ◦ 50% → 75% → 87.5% → 100%
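Two of the ablated choices lend themselves to a short sketch: the function that maps quantization errors to selection probabilities, and the exponential schedule for the quantization ratio. The code below is illustrative only. The softmax form follows the formula on this slide and the linear form follows p_i ∝ 1/e_i; treating "constant" as a uniform distribution is my assumption, the sigmoid variant is omitted because the slides do not give its formula, and the function names and `eps` guard are hypothetical.

```python
import numpy as np

def selection_probabilities(errors, kind="linear", eps=1e-7):
    """Map per-channel quantization errors e_i to probabilities using f_i = 1/e_i."""
    f = 1.0 / (np.asarray(errors, dtype=float) + eps)
    if kind == "constant":
        p = np.ones_like(f)        # assumption: every channel equally likely
    elif kind == "linear":
        p = f                      # p_i proportional to 1/e_i
    elif kind == "softmax":
        p = np.exp(f - f.max())    # shifted for numerical stability
    else:
        raise ValueError(f"unknown kind: {kind}")
    return p / p.sum()

def exponential_ratios(stages=4):
    """The non-quantized part halves at each stage; the last stage quantizes everything."""
    return [1.0 - 0.5 ** k for k in range(1, stages)] + [1.0]

errors = [0.2, 0.05, 0.2, 0.1]
p_linear = selection_probabilities(errors, "linear")
p_softmax = selection_probabilities(errors, "softmax")
ratios = exponential_ratios(4)
print(p_linear.round(3), p_softmax.round(3), ratios)
```

For these errors the softmax variant puts almost all probability on the lowest-error channel, whereas the linear variant spreads it more evenly, which is one plausible way to read the ranking on the slide.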

Page 11

Results -- CIFAR

Method | Bits | CIFAR-10 VGG-9 | CIFAR-10 ResNet-56 | CIFAR-100 VGG-9 | CIFAR-100 ResNet-56
FWN    | 32   | 9.00           | 6.69               | 30.68           | 29.49
BWN    | 1    | 10.67          | 16.42              | 37.68           | 35.01
SQ-BWN | 1    | 9.40           | 7.15               | 35.25           | 31.56
TWN    | 2    | 9.87           | 7.64               | 34.80           | 32.09
SQ-TWN | 2    | 8.37           | 6.20               | 34.24           | 28.90

Table 5: Test error (%) of VGG-9 and ResNet-56 trained with 5 different methods on the CIFAR-10 and CIFAR-100 datasets. Our SQ algorithm can consistently improve the results. SQ-TWN even outperforms full-precision models.

Method | Bits | AlexNet-BN top-1 | AlexNet-BN top-5 | ResNet-18 top-1 | ResNet-18 top-5
FWN    | 32   | 44.18            | 20.83            | 34.80           | 13.60
BWN    | 1    | 51.22            | 27.18            | 45.20           | 21.08
SQ-BWN | 1    | 48.78            | 24.86            | 41.64           | 18.35
TWN    | 2    | 47.54            | 23.81            | 39.83           | 17.02
SQ-TWN | 2    | 44.70            | 21.40            | 36.18           | 14.26

Table 6: Test error (%) of AlexNet-BN and ResNet-18 trained with 5 different methods on the ImageNet dataset.

... simplicity. We find that the performance of the different functions is very close, which indicates that what matters most is the stochastic partition algorithm itself.

Scheme for updating the SQ ratio. We also study how to design the scheme for updating the SQ ratio. We assume that training with stochastic quantization is divided into several stages, each with a fixed SQ ratio. In all previous settings, we use four stages with SQ ratio r = 50%, 75%, 87.5% and 100%. We call this the exponential scheme (Exp), since in each stage the non-quantized part is half the size of that in the previous stage. We compare our results with two additional schemes: the average scheme (Ave) and the fine-tuned scheme (Tune). The average scheme includes five stages with r = 20%, 40%, 60%, 80% and 100%. The fine-tuned scheme fine-tunes a pre-trained full-precision model with the same stages as the exponential scheme, instead of training from scratch.

Table 3 shows the results of the three compared schemes. The exponential scheme performs better than both the average scheme and the fine-tuned scheme. It can also be concluded that our method works well when training from scratch.

4.3 Benchmark Results

We make a full benchmark comparison on the CIFAR-10, CIFAR-100 and ImageNet datasets between the proposed SQ algorithm and the traditional algorithms for BWN and TWN. For a fair comparison, we keep the learning rate, optimization method, batch size, etc. identical across all these experiments. We adopt the best setting (i.e., channel-wise selection, stochastic partition, linear probability function, exponential scheme for the SQ ratio) for all SQ-based experiments.

Table 5 presents the results on the CIFAR-10 and CIFAR-100 datasets. In both cases, SQ-BWN and SQ-TWN outperform BWN and TWN significantly, especially on ResNet-56, which is deeper and whose gradients can be easily misled by quantized weights with large quantization errors. For example, SQ-BWN improves the accuracy over BWN by 9.27% (test error 7.15% versus 16.42%) on the CIFAR-10 dataset with the ResNet-56 network. The 2-bit SQ-TWN models can even obtain higher accuracy than the full-precision models: SQ-TWN improves the accuracy by 0.63%, 0.49% and 0.59% over the full-precision models on CIFAR-10 with VGG-9, CIFAR-10 with ResNet-56 and CIFAR-100 with ResNet-56, respectively.
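The improvements quoted in this paragraph are absolute differences in test error and can be re-derived from Table 5. The short check below is only an arithmetic sanity check; the dictionary layout and the helper name `gain` are illustrative, while the numbers are copied from Table 5.

```python
# Test error (%) from Table 5, keyed by (dataset, network)
table5 = {
    "FWN":    {("CIFAR-10", "VGG-9"): 9.00, ("CIFAR-10", "ResNet-56"): 6.69,
               ("CIFAR-100", "ResNet-56"): 29.49},
    "BWN":    {("CIFAR-10", "ResNet-56"): 16.42},
    "SQ-BWN": {("CIFAR-10", "ResNet-56"): 7.15},
    "SQ-TWN": {("CIFAR-10", "VGG-9"): 8.37, ("CIFAR-10", "ResNet-56"): 6.20,
               ("CIFAR-100", "ResNet-56"): 28.90},
}

def gain(baseline, method, key):
    """Absolute reduction in test error (percentage points) of `method` over `baseline`."""
    return round(table5[baseline][key] - table5[method][key], 2)

sq_bwn_gain = gain("BWN", "SQ-BWN", ("CIFAR-10", "ResNet-56"))
sq_twn_gains = [gain("FWN", "SQ-TWN", key) for key in table5["SQ-TWN"]]
print(sq_bwn_gain, sq_twn_gains)
```

The differences match the figures in the text (9.27 for SQ-BWN over BWN, and 0.63, 0.49 and 0.59 for SQ-TWN over FWN). Note that on CIFAR-100 with VGG-9, which the text does not list, SQ-TWN (34.24) does not beat FWN (30.68).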

[Figure 2: Test loss of FWN, BWN and SQ-BWN on CIFAR-10 with ResNet-56. Axes: loss (0 to 80) versus iterations (0 to 256k).]

[Figure 3: Test loss of FWN, TWN and SQ-TWN on CIFAR-10 with ResNet-56. Axes: loss (0 to 2) versus iterations (0 to 256k).]

In Figure 2 and Figure 3, we show the test loss curves on the CIFAR-10 dataset with the ResNet-56 network. In Figure 2, BWN does not converge well and has a much larger loss than SQ-BWN, while the losses of SQ-BWN and FWN are relatively low. In Figure 3, we can see that at the beginning of each stage, quantizing more weights leads to a large loss, which soon decreases after some training iterations. Finally, SQ-TWN reaches a smaller loss than FWN, while the loss of TWN is larger than that of FWN.

Table 6 shows the results on the standard ImageNet 50K validation set. We compare SQ-BWN and SQ-TWN with FWN, BWN and TWN on the AlexNet-BN and ResNet-18 architectures, and all errors are reported with single center-crop testing only. Our algorithm consistently beats the baseline methods by a large margin, and SQ-TWN yields results approaching those of FWN. Note that some works [18, 24] keep the first or last layer in full precision to alleviate possible accuracy loss. Although we do not do so, SQ-TWN still achieves near full-precision accuracy. We conclude that our SQ algorithm can consistently and significantly improve the performance.

5 Conclusion

In this paper, we propose a Stochastic Quantization (SQ) algorithm to learn accurate low-bit DNNs. We propose a roulette-based partition algorithm that, at each iteration, selects a portion of the weights to quantize while keeping the other portion unchanged in full precision. The hybrid weights provide much more appropriate gradients and lead to a better local minimum. We empirically demonstrate the effectiveness of our algorithm through extensive experiments on various low-bit-width settings, network architectures and benchmark datasets. We also make our code public at https://github.com/dongyp13/Stochastic-Quantization. Future work may consider the stochastic quantization of both weights and activations.

Acknowledgement

This work was done when Yinpeng Dong and Renkun Ni were interns at Intel Labs, supervised by Jianguo Li. Yinpeng Dong, Jun Zhu and Hang Su are also supported by the National Basic Research Program of China (2013CB329403) and the National Natural Science Foundation of China (61620106010, 61621136008).

Page 12

Results -- ImageNet


Method | Bits | AlexNet-BN top-1 | AlexNet-BN top-5 | ResNet-18 top-1 | ResNet-18 top-5
FWN    | 32   | 44.18            | 20.83            | 34.80           | 13.60
BWN    | 1    | 51.22            | 27.18            | 45.20           | 21.08
SQ-BWN | 1    | 48.78            | 24.86            | 41.64           | 18.35
TWN    | 2    | 47.54            | 23.81            | 39.83           | 17.02
SQ-TWN | 2    | 44.70            | 21.40            | 36.18           | 14.26

Table 6: Test error (%) of AlexNet-BN and ResNet-18 trained with 5 different methods on the ImageNet dataset.


Page 13

Conclusions

• We propose a stochastic quantization algorithm for low-bit DNN training;
• Our algorithm can be flexibly applied to all low-bit settings;
• Our algorithm helps to consistently improve the performance;
• We release our code to the public for future development:
    ◦ https://github.com/dongyp13/Stochastic-Quantization

Page 14

Q & A