A BEGINNER’S TORCH7 TUTORIAL
BAOGUANG SHI (石葆光)
HUAZHONG UNIVERSITY OF SCIENCE AND TECHNOLOGY
VALSE 2016, APRIL 22ND, 2016
WUHAN, CHINA
What is Torch7?
• A machine learning programming framework built on Lua(JIT), C++, and CUDA
• Open-source, BSD license
• Nowadays widely used for deep learning
  • Google DeepMind (“Human-level control through deep reinforcement learning”, Nature, 2015)
  • Facebook AI Research
4/24/16 3
A Quick Look
• Construct a multi-layered perceptron model, and perform forward propagation

    require('nn')

    -- create an MLP model
    model = nn.Sequential()
    model:add(nn.Linear(512, 256))
    model:add(nn.ReLU())
    model:add(nn.Linear(256, 10))
    model:add(nn.SoftMax())

    -- generate a random input
    input = torch.Tensor(10, 512):uniform(-1,1)

    -- forward propagation
    output = model:forward(input)
Why Use Torch7?
• Great flexibility
  • Caffe layer (C++): input and output are both vector<caffe::Blob>
  • Torch7 layer (Lua/C++): input and output can each be a tensor, a table of tensors {tensor1, tensor2, …, tensorN}, or an arbitrary table such as {'hello', 'world'} or {tensor, 'hello', 123}
Why Use Torch7?
• Torch7 is fast
  • LuaJIT is blazingly fast
  • CUDA/C++ accelerated backends
* image from http://julialang.org/
Getting Started
• Prerequisites:
  • Linux (Ubuntu 14.04 x64 recommended)
  • (optional) NVIDIA GPU, CUDA
• Install Torch7

    git clone https://github.com/torch/distro.git ~/torch --recursive
    cd ~/torch; bash install-deps;
    ./install.sh

• Install packages

    luarocks install image
    luarocks install cunn

• Run Torch7

    th main.lua
Lua Basics
• Lua is an interpreted language, like Python and MATLAB

    -- this is a comment
    num = 42
    s = 'hello'

    for i = 1, 10 do
      if i % 2 == 0 then
        num = num + 1
      else
        num = num - 2
      end
    end

    function printNumber(n)
      print(string.format('The number is %d', n))
    end

    printNumber(num)
Lua Basics
• Tables: the only compound data structure in Lua
  • Used as dictionary

    person = {id=1234, name='Jane Doe', registered=true}
    print(person.id) -- output 1234
    person.age = 32 -- add a new key 'age'

  • Used as list

    v = {'jane', 36, 'doe'}
    print(v[1]) -- output 'jane'
    print(#v) -- output 3
    v[4] = 32 -- add a new element

  • Used as class container (meta-table)
• Recommended tutorial: “Learn Lua in 15 mins” http://tylerneylon.com/a/learn-lua/
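The third use of tables, as a class container, is the only one without sample code. The following is a minimal sketch in plain Lua, assuming a made-up Account class (the name and its methods are illustrative and do not come from the slides). It shows how a metatable's __index field lets instances find methods stored on the class table.

```lua
-- A minimal class built from a table and a metatable.
-- "Account" is a hypothetical name used only for illustration.
local Account = {}
Account.__index = Account   -- failed lookups on instances fall back to Account

-- Constructor: create an empty table and attach Account as its metatable.
function Account.new(balance)
  local self = setmetatable({}, Account)
  self.balance = balance or 0
  return self
end

-- Method: the colon syntax passes the instance as the implicit "self".
function Account:deposit(amount)
  self.balance = self.balance + amount
end

local acct = Account.new(100)
acct:deposit(50)
print(acct.balance)  -- 150
```

torch.class, covered later in these slides, wraps this same mechanism and adds inheritance and type names.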
Torch7 Core Structure: Tensor
• Tensor: multi-dimensional array
  • Similar to np.array in NumPy and matrix in MATLAB
• Use tensor

    x = torch.Tensor(128,3,224,224)
Torch7 Core Structure: Tensor
• Data types
  • Byte: torch.ByteTensor
  • Integer: torch.IntTensor
  • Long: torch.LongTensor
  • Float: torch.FloatTensor
  • Double: torch.DoubleTensor
  • GPU float: torch.CudaTensor
• Default tensor: torch.Tensor, set by torch.setdefaulttensortype('...')
• Type casting: x:type('torch.FloatTensor')
Tensor Operations
• Create & initialize a tensor

    x = torch.Tensor(5,6)
    x:fill(0) -- fill with zeros
    x:uniform(-1,1) -- fill with random values between -1 and 1
Tensor Operations
• Number of dimensions, total number of elements, size of each dimension

    x:nDimension() -- returns 2
    x:nElement() -- returns 30
    x:size() -- returns a torch.LongStorage holding {5, 6}
    x:size(2) -- returns 6
Object-Oriented Programming
• Torch7 provides torch.class for OOP

    local Rectangle, parent = torch.class('Rectangle', 'Polygon')

    function Rectangle:__init(w, h)
      parent.__init(self)
      self.w = w
      self.h = h
    end

    function Rectangle:area() -- override
      return self.w * self.h
    end

    function Rectangle:numVertices()
      return 4
    end
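The Rectangle class above depends on a Polygon base class that the slides do not show. The sketch below assumes a trivial Polygon (its methods are invented here for illustration) so that the example can be instantiated. Each class sits in its own do-block because torch.class registers the callable constructor under the class name in the global namespace, while the value it returns is the method table.

```lua
require('torch')

do
  -- Hypothetical base class; the slides do not define Polygon.
  local Polygon = torch.class('Polygon')
  function Polygon:__init() end
  function Polygon:area() return 0 end
end

do
  local Rectangle, parent = torch.class('Rectangle', 'Polygon')
  function Rectangle:__init(w, h)
    parent.__init(self)   -- call the base-class initializer
    self.w = w
    self.h = h
  end
  function Rectangle:area() -- override
    return self.w * self.h
  end
end

-- The constructors are now available as the globals Polygon and Rectangle.
local r = Rectangle(3, 4)
print(r:area())        -- 12
print(torch.type(r))   -- Rectangle
```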
The nn Package
• Construct neural networks on CPU/GPU

    require('nn')
    require('cunn')

• Convert a CPU model to GPU

    model:cuda()

• Most neural network layers inherit from the base class nn.Module, whose base methods are:

    Module:updateOutput(input)
    Module:updateGradInput(input, gradOutput)
    Module:accGradParameters(input, gradOutput, scale)

• Criterions (loss functions) are also implemented; they inherit from nn.Criterion

    Criterion:updateOutput(input, target)
    Criterion:updateGradInput(input, target)
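The criterion interface can be tried on its own, without a network. The following is a small sketch using nn.MSECriterion with made-up input and target values. In user code, forward and backward are the usual entry points; they call updateOutput and updateGradInput internally.

```lua
require('nn')

local criterion = nn.MSECriterion()

-- Illustrative values, not taken from the slides.
local input  = torch.Tensor({1, 2, 3})
local target = torch.Tensor({1, 2, 5})

-- forward() returns the loss: mean of squared differences = (0 + 0 + 4) / 3
local loss = criterion:forward(input, target)

-- backward() returns the gradient of the loss with respect to the input:
-- 2 * (input - target) / 3 = {0, 0, -4/3}
local gradInput = criterion:backward(input, target)

print(loss)
print(gradInput)
```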
The nn Package
• About 140 kinds of network layers and criterions implemented (as of April 2016)
• Layers
  • Convolution: nn.SpatialConvolution, nn.SpatialFullConvolution
  • Activation functions: nn.ReLU, nn.Tanh, nn.Sigmoid, nn.PReLU
  • Normalization: nn.BatchNormalization, nn.SpatialBatchNormalization
• Criterions
  • Classification: nn.ClassNLLCriterion, nn.BCECriterion
  • Regression: nn.MSECriterion, nn.WeightedMSECriterion
  • Multi-task: nn.MultiCriterion
The nn Package
• Network containers
[Figure: three container diagrams, each wrapping module-1, module-2, and module-3: nn.Sequential; nn.Concat / nn.ConcatTable; nn.Parallel / nn.ParallelTable]
The nn Package
• Network containers

    require('nn')

    -- create an MLP model
    model = nn.Sequential()
    model:add(nn.Linear(512, 256))
    model:add(nn.ReLU())
    model:add(nn.Linear(256, 10))
    model:add(nn.SoftMax())

    -- generate a random input
    input = torch.Tensor(10, 512):uniform(-1,1)

    -- forward propagation
    output = model:forward(input)
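The MLP above exercises nn.Sequential, which chains modules one after another. The two table containers behave differently, and the sketch below illustrates that with arbitrarily chosen layer sizes. nn.ConcatTable feeds the same input to every member module, while nn.ParallelTable feeds the i-th entry of an input table to the i-th module. Both return a table of outputs.

```lua
require('nn')

-- nn.ConcatTable: every member module receives the same input;
-- the output is a table with one entry per module.
local fanOut = nn.ConcatTable()
fanOut:add(nn.Linear(4, 2))
fanOut:add(nn.Linear(4, 3))
local outs = fanOut:forward(torch.randn(4))
print(outs[1]:size(1), outs[2]:size(1))    -- 2  3

-- nn.ParallelTable: the i-th module receives the i-th entry of the input table.
local perInput = nn.ParallelTable()
perInput:add(nn.Linear(4, 2))
perInput:add(nn.Linear(5, 3))
local outs2 = perInput:forward({torch.randn(4), torch.randn(5)})
print(outs2[1]:size(1), outs2[2]:size(1))  -- 2  3
```

nn.Concat and nn.Parallel follow the same routing, but join the outputs into a single tensor along a chosen dimension instead of returning a table.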
Example: Building a ResNet Block
• ResNet block: Sequential + ConcatTable

    local resBlock = nn.Sequential()

    local branch = nn.ConcatTable()

    local conv = nn.Sequential()
    conv:add(nn.SpatialConvolution(...))
    conv:add(nn.SpatialBatchNormalization(...))
    conv:add(nn.ReLU())
    branch:add(conv)

    branch:add(nn.Identity())

    -- add the two-branch container, then sum the two branch outputs
    resBlock:add(branch)
    resBlock:add(nn.CAddTable())

[Figure: one branch applies Convolution, BatchNorm, and ReLU; the other branch is Identity; the two outputs are summed (+)]
The nn Package
• Training networks (1/2)

    -- create criterion (loss function)
    local criterion = nn.ClassNLLCriterion()

    -- create optimizer (SGD)
    -- optim.sgd is itself the update function, called as optim.sgd(feval, params, state);
    -- the hyperparameters live in the state table passed on every call
    require('optim')
    local optimizer = optim.sgd
    local optimState = {learningRate=1e-3, momentum=0.9}

    -- retrieve all parameters and gradients on parameters
    local params, gradParams = model:getParameters()
The nn Package
• Training networks (2/2)

    local function trainOneBatch(inputBatch, targetBatch)
      model:training()
      local feval = function(params)
        gradParams:zero()
        local outputBatch = model:forward(inputBatch)
        local f = criterion:forward(outputBatch, targetBatch)
        model:backward(inputBatch, criterion:backward(outputBatch, targetBatch))
        return f, gradParams
      end
      optimizer(feval, params, optimState)
    end

    -- training loop
    for i = 1, maxIterations do
      local inputBatch, targetBatch = trainDataset:getData()
      trainOneBatch(inputBatch, targetBatch)
    end
The nngraph Package
• Construct complex graph-structured networks
  • LSTM
  • GRU
  • Attention-based models
  • ...
* image from http://colah.github.io/posts/2015-08-Understanding-LSTMs/
The nngraph Package
• The package nngraph provides a simple way to connect modules

    require('nngraph')

    -- x is a graph node; an input node is created by calling a module instance
    x = nn.Identity()()

    -- x -> ReLU -> y1
    y1 = nn.ReLU()(x)

    -- x -> Linear -> y2
    y2 = nn.Linear(4,4)(x)

    -- z = y1 + y2
    z = nn.CAddTable()({y1, y2})
The nngraph Package
• Create a graph-structured unit

    require('nngraph')

    function makeRnnUnit(nIn, nHidden)
      -- define input interface
      local x = nn.Identity()()
      local prevH = nn.Identity()()
      local inputs = {prevH, x}

      -- define structure
      local x2h = nn.Linear(nIn, nHidden)(x)
      local h2h = nn.Linear(nHidden, nHidden)(prevH)
      local sum = nn.CAddTable()({x2h, h2h})
      local h = nn.Tanh()(sum)

      -- define output interface
      local outputs = {h}
      return nn.gModule(inputs, outputs)
    end
Create a New Network Layer
• Create a new class that inherits nn.Module, and override 3 methods:
  • updateOutput(input), compute forward propagation
  • updateGradInput(input, gradOutput), compute gradients on input
  • accGradParameters(input, gradOutput, scale), compute gradients on parameters

Create a New Network Layer
• Example: a layer that scales every input element by a trainable parameter (y_i = w_i x_i)
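The code for this example did not survive in the slides, so the following is a reconstruction sketch, not the author's original. It assumes a hypothetical class name nn.ElementScale and a non-batched 1-D input whose length equals the number of weights. The built-in nn.CMul offers comparable behaviour.

```lua
require('nn')

-- Hypothetical layer name, chosen for illustration only.
local ElementScale, parent = torch.class('nn.ElementScale', 'nn.Module')

function ElementScale:__init(size)
  parent.__init(self)
  -- nn.Module picks up fields named weight / gradWeight as trainable parameters
  self.weight = torch.Tensor(size):fill(1)
  self.gradWeight = torch.Tensor(size):zero()
end

function ElementScale:updateOutput(input)
  -- forward: y_i = w_i * x_i
  self.output:resizeAs(input):copy(input):cmul(self.weight)
  return self.output
end

function ElementScale:updateGradInput(input, gradOutput)
  -- gradient on input: dL/dx_i = w_i * dL/dy_i
  self.gradInput:resizeAs(gradOutput):copy(gradOutput):cmul(self.weight)
  return self.gradInput
end

function ElementScale:accGradParameters(input, gradOutput, scale)
  scale = scale or 1
  -- gradient on parameters: dL/dw_i += scale * x_i * dL/dy_i
  self.gradWeight:addcmul(scale, input, gradOutput)
end
```

Once defined, the layer can be built with nn.ElementScale(n) and added to a container like any other module. Calling backward on it runs updateGradInput and then accGradParameters.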
cuDNN acceleration
• Use cudnn to accelerate some network layers

    >> luarocks install cudnn
    require('cudnn')

• Replace nn layers with cudnn layers
  • cudnn.SpatialConvolution
  • cudnn.ReLU
  • cudnn.SpatialBatchNormalization
  • (in beta) cudnn.RNN, cudnn.LSTM, cudnn.GRU
Debugging in code
• Use fb.debugger (install fblualib first) for debugging Lua code

    debugger = require('fb.debugger')

    ...
    debugger.enter()
    ...

• After entering the debugging mode, print the value of a variable

    >> p variable-name

• Continue and step

    >> c (continue)
    >> n (next line, skip function calls)
    >> s (step into functions)
What’s More?
• Mixed Lua - C/C++ programming
  • TH: mix Lua and C
  • THPP: mix Lua and C++
  • Lua FFI: directly calls C library functions
• Call Python functions
  • fb.python
• Pretrained models
  • Torch7 Model Zoo: https://github.com/torch/torch7/wiki/ModelZoo
  • Residual network: https://github.com/facebook/fb.resnet.torch