-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathfig.py
More file actions
102 lines (71 loc) · 2.93 KB
/
fig.py
File metadata and controls
102 lines (71 loc) · 2.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
import json
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
class PlotConfig:
    """Size a grid of matplotlib subplots from a named, centimetre-based layout preset.

    Presets are read from ``fig.json`` located next to this module.  Each preset
    supplies ``plot_width``, the four margins (``margin_left``, ``margin_right``,
    ``margin_bottom``, ``margin_top``), the gaps between subplots
    (``space_width``, ``space_height``), the subplot aspect ratio
    ``subplot_ratio`` (height / width) and the font size ``ftsize``.  Lengths
    are in centimetres; they are converted to inches (2.54 cm per inch) when the
    figure is created.

    Note that constructing an instance changes global matplotlib rc settings
    (LaTeX text rendering, serif font, inward-pointing ticks).
    """

    def __init__(self, config_name='manuscript_single', nrow=1, ncol=1, **kwargs):
        """Load the preset ``config_name`` and compute the figure layout.

        Args:
            config_name: Key of the preset in ``fig.json``.
            nrow: Number of subplot rows.
            ncol: Number of subplot columns.
            **kwargs: Values overriding entries of the chosen preset
                (e.g. ``ftsize=8``).

        Raises:
            KeyError: If ``config_name`` is not a preset in ``fig.json``.
        """
        config = self._load_config(config_name, kwargs)
        self.plot_width = config.get('plot_width')
        self.margin_left = config.get('margin_left')
        self.margin_right = config.get('margin_right')
        self.margin_bottom = config.get('margin_bottom')
        self.margin_top = config.get('margin_top')
        self.space_width = config.get('space_width')
        self.space_height = config.get('space_height')
        self.subplot_ratio = config.get('subplot_ratio')
        self.ftsize = config.get('ftsize')
        self.nrow = nrow
        self.ncol = ncol
        self._compute_layout()
        self._apply_rc_params()

    @staticmethod
    def _load_config(config_name, overrides=None, path=None):
        """Return a copy of preset ``config_name`` from a JSON file with ``overrides`` applied.

        Args:
            config_name: Key of the preset in the JSON file.
            overrides: Optional mapping whose entries replace preset values.
            path: JSON file to read; defaults to ``fig.json`` next to this module.

        Raises:
            KeyError: If ``config_name`` is not a key of the JSON file.
        """
        if path is None:
            path = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'fig.json')
        with open(path, 'r', encoding='utf-8') as f:
            configs = json.load(f)
        if config_name not in configs:
            # Fail early with a clear message rather than an opaque error
            # from operating on None further down.
            raise KeyError(
                'unknown plot config {!r}; available: {}'.format(
                    config_name, ', '.join(sorted(configs))))
        # Copy so overrides never alter the loaded preset object itself.
        config = dict(configs[config_name])
        config.update(overrides or {})
        return config

    def _compute_layout(self):
        """Derive subplot width/height and total figure height from margins, gaps and grid shape."""
        # Width left after the side margins and the (ncol - 1) column gaps,
        # shared equally among the columns.
        self.subplot_width = (self.plot_width
                              - self.margin_left - self.margin_right
                              - (self.ncol - 1) * self.space_width
                              ) / self.ncol
        self.subplot_height = self.subplot_width * self.subplot_ratio
        # The preset fixes only the width; the height follows from the rows.
        self.plot_height = (self.nrow * self.subplot_height
                            + self.margin_bottom + self.margin_top
                            + (self.nrow - 1) * self.space_height)

    def _apply_rc_params(self):
        """Set global matplotlib rc options: LaTeX text, serif font, inward ticks."""
        font = {'family': 'serif',
                'weight': 'normal',
                'size': self.ftsize}
        # Render all text through LaTeX, loading amsmath and bm in the preamble.
        plt.rc('text', usetex=True)
        plt.rc('text.latex',
               preamble=r'\usepackage{amsmath}\usepackage{bm}')
        plt.rc('font', **font)
        plt.rc('xtick', direction='in')
        plt.rc('ytick', direction='in')

    def __cm2inch(self, *tupl):
        """Convert lengths in centimetres to inches (the unit of matplotlib's figsize)."""
        cm_per_inch = 2.54
        return tuple(i / cm_per_inch for i in tupl)

    def get_fig(self, **kwargs):
        """Create a white figure of ``plot_width`` x ``plot_height`` centimetres.

        Extra keyword arguments are forwarded to ``matplotlib.pyplot.figure``.
        """
        figsize = self.__cm2inch(self.plot_width, self.plot_height)
        return plt.figure(figsize=figsize, facecolor='w', **kwargs)

    def _axes_rect(self, i=0, j=0):
        """Return the (left, bottom, width, height) figure-fraction rect of subplot (i, j)."""
        # Row 0 is the top row, so the bottom edge of row i sits above the
        # (nrow - 1 - i) rows below it.
        bottom = (self.margin_bottom
                  + (self.nrow - 1 - i)
                  * (self.space_height + self.subplot_height))
        left = (self.margin_left
                + j * (self.space_width + self.subplot_width))
        return (left / self.plot_width,
                bottom / self.plot_height,
                self.subplot_width / self.plot_width,
                self.subplot_height / self.plot_height)

    def get_axes(self, fig, i=0, j=0, **kwargs):
        """Add and return the axes for row ``i`` (0 = top) and column ``j`` (0 = left).

        Extra keyword arguments are forwarded to ``fig.add_axes``.  Indices are
        not range-checked; out-of-range values place the axes outside the grid.
        """
        return fig.add_axes(self._axes_rect(i, j), **kwargs)

    def get_simple(self):
        """Create a figure with a single axes at grid position (0, 0)."""
        fig = self.get_fig()
        ax = self.get_axes(fig)
        return fig, ax
def close(fig):
    """Close the matplotlib figure ``fig`` through ``matplotlib.pyplot.close``."""
    plt.close(fig)