-
Notifications
You must be signed in to change notification settings - Fork 26
Expand file tree
/
Copy pathmodel.lua
More file actions
73 lines (65 loc) · 2.31 KB
/
model.lua
File metadata and controls
73 lines (65 loc) · 2.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
require 'nn'
require 'cunn'
require 'optim'
require 'cudnn'
-- Build the network. Either restore a serialized model from disk (when
-- opt.retrain names a file) or construct a fresh one from the architecture
-- script under models/, then convert it to the requested compute backend.
-- `model` is intentionally global: sibling scripts loaded via paths.dofile
-- (train/test) read it.
if opt.retrain ~= 'none' then
   -- Preloading path: the saved model already carries its backend conversion,
   -- so no conversion is applied here.
   assert(paths.filep(opt.retrain), 'File not found: ' .. opt.retrain)
   print('Loading model from file: ' .. opt.retrain);
   model = torch.load(opt.retrain)
else
   paths.dofile('models/' .. opt.netType .. '.lua')
   print('=> Creating model from file: models/' .. opt.netType .. '.lua')
   model = createModel() -- for the model creation code, check the models/ folder
   -- Backend conversion, dispatched on opt.backend.
   local converters = {
      cudnn = function()
         require 'cudnn'
         cudnn.convert(model, cudnn)
      end,
      cunn = function()
         require 'cunn'
         model = model:cuda()
      end,
      nn = function() end, -- plain CPU nn: nothing to convert
   }
   local convert = converters[opt.backend]
   if convert == nil then
      error'Unsupported backend'
   end
   convert()
end
-- Criterion: negative log-likelihood loss over log-probabilities, paired with
-- the nn.LogSoftMax output layer this script attaches during finetuning.
-- Global on purpose: the train/test scripts loaded elsewhere read `criterion`.
criterion = nn.ClassNLLCriterion()
-- Finetuning: adapt the loaded/created network to a new class count
-- (`nClasses`, defined outside this file). Two modes:
--   'last'  - keep a frozen copy of the trunk in the global `features`
--             (set to evaluate mode) and train only a new classifier head.
--   'whole' - replace the head in place and train the entire network.
if(opt.finetune == 'last' or opt.finetune == 'whole') then
-- remove last two layers and add new ones (for new nClasses)
-- either freeze previous layers (last) or train all layers (whole)
-- NOTE(review): layer index 20 is hard-coded; this assumes a specific
-- architecture where module 20's bias vector has the trunk's output width
-- (2048) -- confirm it matches every opt.netType used with finetuning.
-- This read must happen BEFORE the remove() calls below.
local n_units = model:get(20):parameters()[2]:size()[1] -- 2048
print('=> Model: Removing last two layers.')
model:remove()
model:remove()
if(opt.finetune == 'last') then
-- Snapshot the truncated trunk before `model` is reassigned; evaluate()
-- freezes it (e.g. batch-norm/dropout in inference mode). `features` is a
-- global presumably consumed by the training scripts -- verify against caller.
features = model:clone()
features:evaluate()
if(opt.lastlayer == 'none') then
print('=> Model: Adding fc layer and logsoftmax to train only last layer.')
-- Trainable part is only the new head: fc(n_units -> nClasses) + LogSoftMax.
model = nn.Sequential()
model:add(nn.Linear(n_units, nClasses))
model:add(nn.LogSoftMax())
else
print('=> Model: Adding pre-trained fc layer to train only last layer.')
-- Resume from a previously trained head instead of a fresh one.
model = torch.load(opt.lastlayer)
end
elseif(opt.finetune == 'whole') then
-- Whole-network mode: append the new head to the truncated trunk so every
-- layer receives gradients.
if(opt.lastlayer == 'none') then
print('=> Model: Adding fc layer and logsoftmax to train whole network.')
model:add(nn.Linear(n_units, nClasses))
model:add(nn.LogSoftMax())
else
print('=> Model: Adding pre-trained fc layer to train whole network.')
local classifier = torch.load(opt.lastlayer)
model:add(classifier)
end
end
end
-- Report the final network and loss, then move both onto the GPU.
local function report(header, obj)
   print(header)
   print(obj)
end
report('=> Model', model)
report('=> Criterion', criterion)
model = model:cuda()
criterion:cuda()
-- Reclaim memory held by CPU-side intermediates before training starts.
collectgarbage()