forked from kashizui/deep-colorization
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path forward.lua
More file actions
66 lines (58 loc) · 1.98 KB
/
forward.lua
File metadata and controls
66 lines (58 loc) · 1.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
require 'image'
require 'cudnn'
require 'cunn'
dofile './provider.lua'
-- Command-line options; opt.test is used below as an on/off switch for the test split.
opt = lapp[[
-l,--logs (default "/mnt/logs/colorize") subdirectory to read logs
-p,--provider (default "/mnt/provider.t7") provider
-i,--image (default "100") index into image
-t,--test use test set
]]
-- init
-- opt.image may arrive as a string (its default is quoted), so convert it and
-- reject anything that is not a positive integer before doing expensive loads;
-- otherwise a bad value only surfaces later as an obscure tensor-indexing error.
opt.image = tonumber(opt.image)
assert(opt.image and opt.image >= 1 and opt.image % 1 == 0,
  '--image must be a positive integer index')
provider = torch.load(opt.provider)
model = torch.load(paths.concat(opt.logs, 'model.net'))
-- Normalization statistics (mean/std of U and V) are always taken from the
-- training split, whichever split the sample itself is drawn from.
trainData = provider.trainData
if opt.test then
  data = provider.testData
else
  data = provider.trainData
end
-- Bounds-check against the selected split; dimension 1 of data.gray is the
-- sample index (it is indexed that way when building the network input).
local numImages = data.gray:size(1)
assert(opt.image <= numImages,
  string.format('--image %d is out of range (split has %d images)',
    opt.image, numImages))
-- Resuscitate the original color image of sample opt.image.
-- Indexing data.data returns storage shared with the loaded provider, so work
-- on a copy: the in-place ops below would otherwise rewrite the dataset itself.
yuvTrue = data.data[opt.image]:clone()
-- Y (channel 1) is only rescaled by 1/256 here; presumably it is stored
-- unnormalized in the provider -- TODO confirm against provider.lua.
yuvTrue:select(1,1):div(256)
-- U and V: undo the per-channel normalization with the training-set
-- statistics, then rescale by 1/256 (in-place ops return the tensor itself).
yuvTrue:select(1,2):mul(trainData.std_u):add(trainData.mean_u):div(256)
yuvTrue:select(1,3):mul(trainData.std_v):add(trainData.mean_v):div(256)
rgbTrue = image.yuv2rgb(yuvTrue)
-- get original unnormalized gray channel
-- The spatial size is read from the sample (yuvTrue is channels x H x W)
-- instead of being hard-coded to 32x32, so other image sizes work too.
local height, width = yuvTrue:size(2), yuvTrue:size(3)
-- Y channel of the ground truth, kept as 1 x H x W on the GPU so it can be
-- concatenated with the predicted U/V planes along dimension 1.
grayOrig = torch.CudaTensor(1, height, width)
grayOrig:copy(yuvTrue:index(1, torch.LongTensor{1}):float())
-- get normalized gray image
-- Network input for this sample as a batch of one: 1 x 1 x H x W.
-- Assumes data.gray has the same H x W as data.data -- TODO confirm.
gray = torch.CudaTensor(1, 1, height, width)
gray:copy(data.gray:index(1, torch.LongTensor{opt.image}):float())
-- Forward the normalized gray input through the network, then assemble a YUV
-- image from the original Y channel and the predicted U/V channels.
uvPred = model:forward(gray)
-- uvPred[1] takes the single sample out of the batch; join along channels.
yuvPred = torch.cat(grayOrig, uvPred[1], 1)
-- De-normalize the predicted chroma: scale by the training std, shift by the
-- training mean, then rescale by 1/256 -- U (channel 2) first, then V (channel 3).
local chromaStats = {
  { channel = 2, std = trainData.std_u, mean = trainData.mean_u },
  { channel = 3, std = trainData.std_v, mean = trainData.mean_v },
}
for _, stat in ipairs(chromaStats) do
  local plane = yuvPred:select(1, stat.channel)
  plane:mul(stat.std)
  plane:add(stat.mean)
  plane:div(256)
end
rgbPred = image.yuv2rgb(yuvPred)
-- Print the input and output sizes as a quick sanity check.
for _, shown in ipairs({ gray, rgbPred, rgbTrue }) do
  print(shown:size())
end
-- Save the gray input, the colorized prediction and the ground truth.
-- Test-set samples get a 'test' prefix so they do not overwrite the files of
-- the training sample with the same index.
local stem = opt.test and ('test' .. opt.image) or opt.image
basepath = paths.concat(opt.logs, stem)
local outputs = {
  { suffix = 'in.png', picture = grayOrig },
  { suffix = 'out.png', picture = rgbPred },
  { suffix = 'true.png', picture = rgbTrue },
}
for _, out in ipairs(outputs) do
  image.save(basepath .. out.suffix, out.picture)
end