-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathDataPrepare.py
More file actions
107 lines (83 loc) · 3.15 KB
/
DataPrepare.py
File metadata and controls
107 lines (83 loc) · 3.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import json
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset
class InputData(Dataset):
    """Paired style / sketch dataset read from side-by-side PNG files.

    Each ``.png`` in ``<datadir>/<mode>`` holds two images side by side: the
    left half is returned as the original (style) image and the right half as
    the sketch. Optionally a colour-histogram tensor is built from a per-image
    colorgram JSON file named ``<image stem>.json``.
    """

    def __init__(self, datadir='./data', mode='train', transform=None,
                 colorhisto=True, size=500, colorgramdir='./data/colorgram'):
        """
        :param datadir: root data directory; images are read from ``<datadir>/<mode>``
        :param mode: sub-directory to read, e.g. ``train`` / ``validate``
        :param transform: optional callable applied to both output images
        :param colorhisto: whether to build the colour-histogram tensor
        :param size: side length (pixels) of the square output images/tensor
        :param colorgramdir: directory holding the ``<name>.json`` colorgram
            files. The default is the path that used to be hard-coded; note
            that it does NOT follow ``datadir``.
        """
        super().__init__()
        datadir = os.path.join(datadir, mode)
        self.datadir = datadir
        self.mode = mode
        self.transform = transform
        self.colorhisto = colorhisto
        self.size = size
        self.colorgramdir = colorgramdir
        # Sorted so a given index always maps to the same file; os.listdir
        # returns entries in arbitrary order.
        self.filelist = sorted(
            name for name in os.listdir(datadir)
            if os.path.splitext(name)[-1] == '.png'
        )

    def __len__(self):
        """Return the number of ``.png`` files found in the data directory."""
        return len(self.filelist)

    def __getitem__(self, index):
        """Return ``getData(index)``, making this a usable map-style Dataset.

        NOTE(review): with ``colorhisto=False`` the third element is ``None``,
        which torch's default collate function cannot batch.
        """
        return self.getData(index)

    def _load_colorgram(self, filename):
        """Build the colour-histogram tensor for ``filename`` from its JSON file.

        The JSON is read as an object whose values are dicts keyed
        ``"1".."topnum"``, each mapping to an ``(r, g, b)`` triple
        (NOTE(review): assumed to be on a 0-255 scale, since the result is
        divided by 255 - confirm against the colorgram generator).

        The i-th value fills horizontal band i (``size // 4`` rows tall) of
        channels ``3*(j-1) .. 3*(j-1)+2`` with colour j. Bands that start past
        the bottom edge are empty slices and are silently dropped; rows not
        covered by any band keep the initial value 1 (1/255 after scaling).

        :param filename: image file name without extension
        :return: float64 tensor of shape ``(3 * topnum, size, size)``
        """
        path = os.path.join(self.colorgramdir, filename + '.json')
        with open(path, 'r', encoding='utf-8') as fh:
            colorgram = json.load(fh)
        colors = list(colorgram.values())
        topnum = len(colors[0])
        tensor = np.ones([topnum * 3, self.size, self.size])
        band = self.size // 4
        for i, color in enumerate(colors):
            rows = slice(band * i, band * (i + 1))
            for j in range(topnum):
                r, g, b = color[str(j + 1)]
                # Bands never overlap, so plain assignment equals scaling the ones.
                tensor[3 * j, rows] = r
                tensor[3 * j + 1, rows] = g
                tensor[3 * j + 2, rows] = b
        return torch.from_numpy(tensor) / 255.

    def getData(self, index):
        """
        Load one sample.

        :param index: position of the file in ``self.filelist``
        :return: ``(original, sketch, color_histogram)``. The two images are
            padded/centre-cropped PIL images, or whatever ``self.transform``
            returns for them when a transform is set. ``color_histogram`` is a
            float64 tensor of shape ``(3 * topnum, size, size)``, or ``None``
            when ``colorhisto`` is False.
        """
        file = self.filelist[index]
        filepath = os.path.join(self.datadir, file)
        filename = os.path.splitext(file)[0]

        color_histogram = self._load_colorgram(filename) if self.colorhisto else None

        # Left half = original, right half = sketch. crop() loads the pixel
        # data, so the crops remain valid after the file is closed.
        with Image.open(filepath) as image:
            width, height = image.size
            half = width // 2
            original = image.crop((0, 0, half, height))
            sketch = image.crop((half, 0, width, height))

        # Not a rescale: pad each side by (deficit // 2 + 1) so the image is
        # at least size+1 along each axis, then centre-crop to size x size.
        # Images larger than ``size`` are therefore cropped, not scaled. Both
        # halves use the padding computed from the left half's width.
        pad_w = max(self.size - half, 0) // 2 + 1
        pad_h = max(self.size - height, 0) // 2 + 1
        pad_and_crop = transforms.Compose([
            # (left/right, top/bottom) padding; fill 255 is white for 8-bit L/RGB.
            transforms.Pad((pad_w, pad_h), fill=255),
            transforms.CenterCrop(self.size),
        ])
        original = pad_and_crop(original)
        sketch = pad_and_crop(sketch)

        if self.transform:
            original = self.transform(original)
            sketch = self.transform(sketch)
        return original, sketch, color_histogram