(Original) The Torch training process
Published: 2019-06-10


Please credit the source when reposting:

Reference URLs:

1. 使用updateParameters

Suppose we already have model = setupmodel (a model we built ourselves), along with training data input, the ground-truth output outReal, and a loss function criterion (see the second reference URL above). The Torch training procedure then looks like this:

1 -- given model, criterion, input, outReal
2 model:training()
3 model:zeroGradParameters()
4 outPredict = model:forward(input)
5 err = criterion:forward(outPredict, outReal)
6 grad_criterion = criterion:backward(outPredict, outReal)
7 model:backward(input, grad_criterion)
8 model:updateParameters(learningRate)

Line 1 lists the quantities assumed to be given.

Line 2 puts the model in training mode.

Line 3 zeroes the gradients stored in every module of model (so that earlier iterations do not contaminate this one).

Line 4 runs input through model to obtain the predicted output outPredict.

Line 5 uses the loss function to compute the error err between the prediction outPredict and the ground truth outReal under the current parameters.

Line 6 computes the gradient grad_criterion of the loss function from outPredict and outReal.

Line 7 back-propagates through model, computing each module's gradients.

Line 8 updates the parameters of every module in model.

 

Lines 3 through 8 must be executed at every iteration, as the sketch below illustrates.
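As a concrete illustration, here is a minimal self-contained sketch that wraps lines 3-8 in such a loop. The tiny linear model, the random data, and the iteration count are invented for the example; they stand in for the real setupmodel, input, and outReal:

require 'torch'
require 'nn'

-- hypothetical stand-ins for setupmodel, input, outReal
local model = nn.Sequential():add(nn.Linear(10, 1))
local criterion = nn.MSECriterion()
local input = torch.randn(10)
local outReal = torch.randn(1)
local learningRate = 0.01

model:training()                               -- line 2
for iter = 1, 100 do                           -- lines 3-8 run every iteration
    model:zeroGradParameters()                 -- line 3: clear stored gradients
    local outPredict = model:forward(input)                        -- line 4
    local err = criterion:forward(outPredict, outReal)             -- line 5
    local grad_criterion = criterion:backward(outPredict, outReal) -- line 6
    model:backward(input, grad_criterion)      -- line 7: accumulate gradients
    model:updateParameters(learningRate)       -- line 8: plain gradient step
end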

 

=========================================================

2. Using optim

Update 2017-03-01:

The reference above gives a more convenient approach (whether it is actually more convenient is debatable): use Torch's optim package to update the parameters. Calling model:updateParameters directly limits you to the simplest gradient descent, whereas optim wraps many algorithms (gradient descent, Adam, and so on). The generic call is:

params_new, fs, ... = optim._method_(feval, params[, config][, state])

where params is the current parameter vector (a 1D tensor), updated in place during optimization

feval: a user-defined closure of the form f, df/dx = feval(x)

config: a table of algorithm options (e.g. the learning rate)

state: a table of state variables

params_new: the new parameter vector (a 1D tensor) that minimizes f

fs: a table of f values evaluated during the optimization; fs[#fs] is the optimized function value. A toy example of this calling convention follows.
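To make the calling convention concrete, here is a toy sketch (everything in it is invented for illustration) that minimizes f(x) = (x - 2)^2 with optim.sgd:

require 'torch'
require 'optim'

local params = torch.Tensor{0}        -- current parameter vector, 1D as required
local config = {learningRate = 0.1}

-- feval returns f(x) and df/dx, matching f, df/dx = feval(x)
local function feval(x)
    local f = (x[1] - 2)^2
    local df_dx = torch.Tensor{2 * (x[1] - 2)}
    return f, df_dx
end

local params_new, fs
for i = 1, 50 do
    params_new, fs = optim.sgd(feval, params, config)
end
print(params[1], fs[#fs])  -- params[1] approaches the minimizer 2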

Note: since optim requires its input to be a 1D tensor, the model's parameters must first be flattened, via:

params, gradParams = model:getParameters()

Both params and gradParams are 1D tensors; they are flat views over the model's parameters and parameter gradients, as the sketch below shows.
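A quick sketch with a toy module (invented for illustration) shows what the flattening means in practice: params is a single 1D view over all of the model's weights, so writing to it writes to the model:

require 'nn'

local model = nn.Linear(3, 2)
local params, gradParams = model:getParameters()
print(params:nElement())   -- 8: the 2x3 weight matrix plus 2 biases, flattened
params:fill(0)             -- writes through to model.weight and model.bias
print(model.weight:sum())  -- 0, because params is a view, not a copy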

With this in place, the program from section 1 can be rewritten as below. Note that this feval performs no computation itself: the forward/backward passes run explicitly in the loop, and feval merely hands the already-computed loss and gradients to the optimizer.

-- given model, criterion, input, outReal, optimState
local params, gradParams = model:getParameters()
local function feval()
    return criterion.output, gradParams
end
for ...
    model:training()
    model:zeroGradParameters()
    outPredict = model:forward(input)
    err = criterion:forward(outPredict, outReal)
    grad_criterion = criterion:backward(outPredict, outReal)
    model:backward(input, grad_criterion)

    optim.sgd(feval, params, optimState)
end
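Because feval and the flattened params are optimizer-agnostic, switching algorithms only means changing the method name and its option table. A sketch, assuming the standard optim.adam options:

local adamState = {learningRate = 1e-3, beta1 = 0.9, beta2 = 0.999}
-- in the loop above, replace the optim.sgd call with:
-- optim.adam(feval, params, adamState)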

End of 2017-03-01 update

=========================================================

3. A pitfall when using model:backward

Update 2017-04-05

Note that model:backward must always be paired with a matching model:forward.

The documentation for [gradInput] backward(input, gradOutput) (linked above) states:

In general this method makes the assumption forward(input) has been called before, with the same input. This is necessary for optimization reasons. If you do not respect this rule, backward() will compute incorrect gradients.

Presumably backward() reuses intermediate values cached during forward(), so forward() must run, with the same input, immediately before backward(); otherwise the cached intermediates do not match the input, and the computed gradients are wrong.
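The effect is easy to demonstrate with a toy module (all names and data below are invented for the illustration): running forward on a second input overwrites the cached intermediates, so a subsequent backward against the first input silently produces wrong gradients:

require 'torch'
require 'nn'

local model = nn.Sequential():add(nn.Linear(4, 2)):add(nn.Sigmoid())
local a, b = torch.randn(4), torch.randn(4)
local gradOutput = torch.randn(2)

model:forward(a)
local gradCorrect = model:backward(a, gradOutput):clone()

model:forward(a)
model:forward(b)  -- overwrites the intermediates cached for a
local gradWrong = model:backward(a, gradOutput):clone()  -- uses b's cached values

print((gradCorrect - gradWrong):abs():max())  -- nonzero: the gradients differ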

In my own program, the intermediates that ended up cached after the initial forward passes were those of the last forward call, not the one matching the backward, so the backward results were consistently wrong (see the commented-out line in method 5).

The only (admittedly clumsy) fix was to run forward once more right before each backward, which guarantees correct results at the cost of one extra forward pass; see method 5 in the code below.

Notes on the five methods:

- method 1 is the conventional full-batch approach, but it is demanding on GPU memory.
- method 2 therefore splits the batch into mini-batches, similar to caffe's iter_size (though not identical to it); the gradients accumulated by model:backward are rescaled by miniBatchSize/batchSize so that they average over the whole batch.
- If you need still more samples, and want the criterion to see as many samples at once as possible, both of the first two methods fall short; method 3 is meant for that case, but in practice it is broken: the loss converges very slowly.
- method 4 refines and tests method 3. With the two marked lines commented out (as in the listing) it converges normally; with them uncommented, convergence breaks down just as in method 3.
- method 5 is the final fix, and it converges normally. To verify that forward and backward must be paired, uncomment the commented line in method 5 and comment out the line above it: convergence again becomes very slow (as in methods 3 and 4).

The convergence of each method over the first 10 epochs is listed after the code.

The full program is as follows:

require 'torch'
require 'nn'
require 'optim'
require 'cunn'
require 'cutorch'
local mnist = require 'mnist'

local fullset = mnist.traindataset()
local testset = mnist.testdataset()

local trainset = {
    size = 50000,
    data = fullset.data[{{1,50000}}]:double(),
    label = fullset.label[{{1,50000}}]
}
trainset.data = trainset.data - trainset.data:mean()
trainset.data = trainset.data:cuda()
trainset.label = trainset.label:cuda()

local validationset = {
    size = 10000,
    data = fullset.data[{{50001,60000}}]:double(),
    label = fullset.label[{{50001,60000}}]
}
validationset.data = validationset.data - validationset.data:mean()
validationset.data = validationset.data:cuda()
validationset.label = validationset.label:cuda()

local model = nn.Sequential()
model:add(nn.Reshape(1, 28, 28))
model:add(nn.MulConstant(1/256.0*3.2))
model:add(nn.SpatialConvolutionMM(1, 20, 5, 5, 1, 1, 0, 0))
model:add(nn.SpatialMaxPooling(2, 2, 2, 2, 0, 0))
model:add(nn.SpatialConvolutionMM(20, 50, 5, 5, 1, 1, 0, 0))
model:add(nn.SpatialMaxPooling(2, 2, 2, 2, 0, 0))
model:add(nn.Reshape(4*4*50))
model:add(nn.Linear(4*4*50, 500))
model:add(nn.ReLU())
model:add(nn.Linear(500, 10))
model:add(nn.LogSoftMax())

model = require('weight-init')(model, 'xavier')
model = model:cuda()

x, dl_dx = model:getParameters()

local criterion = nn.ClassNLLCriterion():cuda()

local sgd_params = {
    learningRate = 1e-2,
    learningRateDecay = 1e-4,
    weightDecay = 1e-3,
    momentum = 1e-4
}

local training = function(batchSize)
    local current_loss = 0
    local count = 0
    local shuffle = torch.randperm(trainset.size)
    batchSize = batchSize or 200
    for t = 0, trainset.size-1, batchSize do
        -- setup inputs and targets for batch iteration
        local size = math.min(t + batchSize, trainset.size) - t
        local inputs = torch.Tensor(size, 28, 28):cuda()
        local targets = torch.Tensor(size):cuda()
        for i = 1, size do
            inputs[i] = trainset.data[shuffle[i+t]]
            targets[i] = trainset.label[shuffle[i+t]] + 1
        end

        local feval = function(x_new)
            local miniBatchSize = 20
            if x ~= x_new then x:copy(x_new) end -- reset data
            dl_dx:zero()

            --[[ ------------------ method 1 original batch
            local outputs = model:forward(inputs)
            local loss = criterion:forward(outputs, targets)
            local gradInput = criterion:backward(outputs, targets)
            model:backward(inputs, gradInput)
            --]]

            --[[ ------------------ method 2 iter-size with batch
            local loss = 0
            for idx = 1, batchSize, miniBatchSize do
                local outputs = model:forward(inputs[{{idx, idx + miniBatchSize - 1}}])
                loss = loss + criterion:forward(outputs, targets[{{idx, idx + miniBatchSize - 1}}])
                local gradInput = criterion:backward(outputs, targets[{{idx, idx + miniBatchSize - 1}}])
                model:backward(inputs[{{idx, idx + miniBatchSize - 1}}], gradInput)
            end
            dl_dx:mul(1.0 * miniBatchSize / batchSize)
            loss = loss * miniBatchSize / batchSize
            --]]

            --[[ ------------------ method 3 mini-batch in batch
            local outputs = torch.Tensor(batchSize, 10):zero():cuda()
            for idx = 1, batchSize, miniBatchSize do
                outputs[{{idx, idx + miniBatchSize - 1}}]:copy(model:forward(inputs[{{idx, idx + miniBatchSize - 1}}]))
            end
            local loss = 0
            for idx = 1, batchSize, miniBatchSize do
                loss = loss + criterion:forward(outputs[{{idx, idx + miniBatchSize - 1}}],
                    targets[{{idx, idx + miniBatchSize - 1}}])
            end
            local gradInput = torch.Tensor(batchSize, 10):zero():cuda()
            for idx = 1, batchSize, miniBatchSize do
                gradInput[{{idx, idx + miniBatchSize - 1}}]:copy(criterion:backward(
                    outputs[{{idx, idx + miniBatchSize - 1}}], targets[{{idx, idx + miniBatchSize - 1}}]))
            end
            for idx = 1, batchSize, miniBatchSize do
                model:backward(inputs[{{idx, idx + miniBatchSize - 1}}], gradInput[{{idx, idx + miniBatchSize - 1}}])
            end
            dl_dx:mul(1.0 * miniBatchSize / batchSize)
            loss = loss * miniBatchSize / batchSize
            --]]

            --[[ ------------------ method 4 mini-batch in batch
            local outputs = torch.Tensor(batchSize, 10):zero():cuda()
            local loss = 0
            local gradInput = torch.Tensor(batchSize, 10):zero():cuda()
            for idx = 1, batchSize, miniBatchSize do
                outputs[{{idx, idx + miniBatchSize - 1}}]:copy(model:forward(inputs[{{idx, idx + miniBatchSize - 1}}]))
                loss = loss + criterion:forward(outputs[{{idx, idx + miniBatchSize - 1}}],
                    targets[{{idx, idx + miniBatchSize - 1}}])
                gradInput[{{idx, idx + miniBatchSize - 1}}]:copy(criterion:backward(
                    outputs[{{idx, idx + miniBatchSize - 1}}], targets[{{idx, idx + miniBatchSize - 1}}]))
            -- end
            -- for idx = 1, batchSize, miniBatchSize do
                model:backward(inputs[{{idx, idx + miniBatchSize - 1}}], gradInput[{{idx, idx + miniBatchSize - 1}}])
            end

            dl_dx:mul(1.0 * miniBatchSize / batchSize)
            loss = loss * miniBatchSize / batchSize
            --]]

            ---[[ ------------------ method 5 mini-batch in batch
            local loss = 0
            local gradInput = torch.Tensor(batchSize, 10):zero():cuda()

            for idx = 1, batchSize, miniBatchSize do
                local outputs = model:forward(inputs[{{idx, idx + miniBatchSize - 1}}])
                loss = loss + criterion:forward(outputs, targets[{{idx, idx + miniBatchSize - 1}}])
                gradInput[{{idx, idx + miniBatchSize - 1}}]:copy(criterion:backward(outputs, targets[{{idx, idx + miniBatchSize - 1}}]))
            end

            for idx = 1, batchSize, miniBatchSize do
                model:forward(inputs[{{idx, idx + miniBatchSize - 1}}])
                --model:forward(inputs[{{batchSize - miniBatchSize + 1, batchSize}}])
                model:backward(inputs[{{idx, idx + miniBatchSize - 1}}], gradInput[{{idx, idx + miniBatchSize - 1}}])
            end

            dl_dx:mul(1.0 * miniBatchSize / batchSize)
            loss = loss * miniBatchSize / batchSize
            --]]

            return loss, dl_dx
        end

        _, fs = optim.sgd(feval, x, sgd_params)

        count = count + 1
        current_loss = current_loss + fs[1]
    end

    return current_loss / count -- normalize loss
end

local eval = function(dataset, batchSize)
    local count = 0
    batchSize = batchSize or 200

    for i = 1, dataset.size, batchSize do
        local size = math.min(i + batchSize - 1, dataset.size) - i
        local inputs = dataset.data[{{i,i+size-1}}]:cuda()
        local targets = dataset.label[{{i,i+size-1}}]
        local outputs = model:forward(inputs)
        local _, indices = torch.max(outputs, 2)
        indices:add(-1)
        indices = indices:cuda()
        local guessed_right = indices:eq(targets):sum()
        count = count + guessed_right
    end

    return count / dataset.size
end

local max_iters = 50
local last_accuracy = 0
local decreasing = 0
local threshold = 1 -- how many decreasing epochs we allow
for i = 1, max_iters do
    -- timer = torch.Timer()

    model:training()
    local loss = training()

    model:evaluate()
    local accuracy = eval(validationset)
    print(string.format('Epoch: %d Current loss: %4f; validation set accu: %4f', i, loss, accuracy))
    if accuracy < last_accuracy then
        if decreasing > threshold then break end
        decreasing = decreasing + 1
    else
        decreasing = 0
    end
    last_accuracy = accuracy

    --print(' Time elapsed: ' .. i .. 'iter: ' .. timer:time().real .. ' seconds')
end

testset.data = testset.data:double()
eval(testset)

weight-init.lua

--
-- Different weight initialization methods
--
-- > model = require('weight-init')(model, 'heuristic')
--
require("nn")


-- "Efficient backprop"
-- Yann Lecun, 1998
local function w_init_heuristic(fan_in, fan_out)
   return math.sqrt(1/(3*fan_in))
end

-- "Understanding the difficulty of training deep feedforward neural networks"
-- Xavier Glorot, 2010
local function w_init_xavier(fan_in, fan_out)
   return math.sqrt(2/(fan_in + fan_out))
end

-- "Understanding the difficulty of training deep feedforward neural networks"
-- Xavier Glorot, 2010
local function w_init_xavier_caffe(fan_in, fan_out)
   return math.sqrt(1/fan_in)
end

-- "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification"
-- Kaiming He, 2015
local function w_init_kaiming(fan_in, fan_out)
   return math.sqrt(4/(fan_in + fan_out))
end

local function w_init(net, arg)
   -- choose initialization method
   local method = nil
   if     arg == 'heuristic'    then method = w_init_heuristic
   elseif arg == 'xavier'       then method = w_init_xavier
   elseif arg == 'xavier_caffe' then method = w_init_xavier_caffe
   elseif arg == 'kaiming'      then method = w_init_kaiming
   else
      assert(false)
   end

   -- loop over all convolutional modules
   for i = 1, #net.modules do
      local m = net.modules[i]
      if m.__typename == 'nn.SpatialConvolution' then
         m:reset(method(m.nInputPlane*m.kH*m.kW, m.nOutputPlane*m.kH*m.kW))
      elseif m.__typename == 'nn.SpatialConvolutionMM' then
         m:reset(method(m.nInputPlane*m.kH*m.kW, m.nOutputPlane*m.kH*m.kW))
      elseif m.__typename == 'cudnn.SpatialConvolution' then
         m:reset(method(m.nInputPlane*m.kH*m.kW, m.nOutputPlane*m.kH*m.kW))
      elseif m.__typename == 'nn.LateralConvolution' then
         m:reset(method(m.nInputPlane*1*1, m.nOutputPlane*1*1))
      elseif m.__typename == 'nn.VerticalConvolution' then
         m:reset(method(1*m.kH*m.kW, 1*m.kH*m.kW))
      elseif m.__typename == 'nn.HorizontalConvolution' then
         m:reset(method(1*m.kH*m.kW, 1*m.kH*m.kW))
      elseif m.__typename == 'nn.Linear' then
         m:reset(method(m.weight:size(2), m.weight:size(1)))
      elseif m.__typename == 'nn.TemporalConvolution' then
         m:reset(method(m.weight:size(2), m.weight:size(1)))
      end

      if m.bias then
         m.bias:zero()
      end
   end
   return net
end

return w_init
Method 1
Epoch: 1 Current loss: 0.616950; validation set accu: 0.920900
Epoch: 2 Current loss: 0.228665; validation set accu: 0.942400
Epoch: 3 Current loss: 0.168047; validation set accu: 0.957900
Epoch: 4 Current loss: 0.134796; validation set accu: 0.961800
Epoch: 5 Current loss: 0.113071; validation set accu: 0.966200
Epoch: 6 Current loss: 0.098782; validation set accu: 0.968800
Epoch: 7 Current loss: 0.088252; validation set accu: 0.970000
Epoch: 8 Current loss: 0.080225; validation set accu: 0.971200
Epoch: 9 Current loss: 0.073702; validation set accu: 0.972200
Epoch: 10 Current loss: 0.068171; validation set accu: 0.972400

Method 2
Epoch: 1 Current loss: 0.624633; validation set accu: 0.922200
Epoch: 2 Current loss: 0.238459; validation set accu: 0.945200
Epoch: 3 Current loss: 0.174089; validation set accu: 0.959000
Epoch: 4 Current loss: 0.140234; validation set accu: 0.963800
Epoch: 5 Current loss: 0.116498; validation set accu: 0.968000
Epoch: 6 Current loss: 0.101376; validation set accu: 0.968800
Epoch: 7 Current loss: 0.089484; validation set accu: 0.972600
Epoch: 8 Current loss: 0.080812; validation set accu: 0.973000
Epoch: 9 Current loss: 0.073929; validation set accu: 0.975100
Epoch: 10 Current loss: 0.068330; validation set accu: 0.975400

Method 3
Epoch: 1 Current loss: 2.202240; validation set accu: 0.548500
Epoch: 2 Current loss: 2.049710; validation set accu: 0.669300
Epoch: 3 Current loss: 1.993560; validation set accu: 0.728900
Epoch: 4 Current loss: 1.959818; validation set accu: 0.774500
Epoch: 5 Current loss: 1.945992; validation set accu: 0.757600
Epoch: 6 Current loss: 1.930599; validation set accu: 0.809600
Epoch: 7 Current loss: 1.911803; validation set accu: 0.837200
Epoch: 8 Current loss: 1.904754; validation set accu: 0.842100
Epoch: 9 Current loss: 1.903705; validation set accu: 0.846400
Epoch: 10 Current loss: 1.903911; validation set accu: 0.848100

Method 4
Epoch: 1 Current loss: 0.624240; validation set accu: 0.924900
Epoch: 2 Current loss: 0.213469; validation set accu: 0.948500
Epoch: 3 Current loss: 0.156797; validation set accu: 0.959800
Epoch: 4 Current loss: 0.126438; validation set accu: 0.963900
Epoch: 5 Current loss: 0.106664; validation set accu: 0.965900
Epoch: 6 Current loss: 0.094166; validation set accu: 0.967200
Epoch: 7 Current loss: 0.084848; validation set accu: 0.971200
Epoch: 8 Current loss: 0.077244; validation set accu: 0.971800
Epoch: 9 Current loss: 0.071417; validation set accu: 0.973300
Epoch: 10 Current loss: 0.065737; validation set accu: 0.971600

Method 4, with the two marked lines uncommented
Epoch: 1 Current loss: 2.178319; validation set accu: 0.542200
Epoch: 2 Current loss: 2.031493; validation set accu: 0.648700
Epoch: 3 Current loss: 1.982282; validation set accu: 0.703700
Epoch: 4 Current loss: 1.956709; validation set accu: 0.762700
Epoch: 5 Current loss: 1.927590; validation set accu: 0.808100
Epoch: 6 Current loss: 1.924535; validation set accu: 0.817200
Epoch: 7 Current loss: 1.911364; validation set accu: 0.820100
Epoch: 8 Current loss: 1.898206; validation set accu: 0.855400
Epoch: 9 Current loss: 1.885394; validation set accu: 0.836500
Epoch: 10 Current loss: 1.880787; validation set accu: 0.870200

Method 5
Epoch: 1 Current loss: 0.619814; validation set accu: 0.924300
Epoch: 2 Current loss: 0.232870; validation set accu: 0.948800
Epoch: 3 Current loss: 0.172606; validation set accu: 0.954900
Epoch: 4 Current loss: 0.137763; validation set accu: 0.961800
Epoch: 5 Current loss: 0.116268; validation set accu: 0.967700
Epoch: 6 Current loss: 0.101985; validation set accu: 0.968800
Epoch: 7 Current loss: 0.091154; validation set accu: 0.970900
Epoch: 8 Current loss: 0.083219; validation set accu: 0.972700
Epoch: 9 Current loss: 0.074921; validation set accu: 0.972800
Epoch: 10 Current loss: 0.070208; validation set accu: 0.972800

Method 5, with the commented forward uncommented and the line above it commented out
Epoch: 1 Current loss: 2.161032; validation set accu: 0.497500
Epoch: 2 Current loss: 2.027255; validation set accu: 0.690900
Epoch: 3 Current loss: 1.972939; validation set accu: 0.767600
Epoch: 4 Current loss: 1.940982; validation set accu: 0.766000
Epoch: 5 Current loss: 1.933135; validation set accu: 0.812800
Epoch: 6 Current loss: 1.913039; validation set accu: 0.799300
Epoch: 7 Current loss: 1.896871; validation set accu: 0.848800
Epoch: 8 Current loss: 1.899655; validation set accu: 0.854400
Epoch: 9 Current loss: 1.889465; validation set accu: 0.845700
Epoch: 10 Current loss: 1.878703; validation set accu: 0.846400

End of 2017-04-05 update

=========================================================

 

Reposted from: https://www.cnblogs.com/darkknightzh/p/6221622.html
