Skip to content

Commit 38ccc46

Browse files
committed
Fix tests
1 parent b7b2bd7 commit 38ccc46

File tree

6 files changed

+55
-58
lines changed

6 files changed

+55
-58
lines changed

mplaltair/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,11 @@ def _handle_line(chart, ax):
6262
groups = []
6363
kwargs = {}
6464

65-
if chart.encoding['opacity']:
65+
if chart.encoding.get('opacity'):
6666
groups.append('opacity')
67-
if chart.encoding['stroke']:
67+
if chart.encoding.get('stroke'):
6868
groups.append('stroke')
69-
elif chart.encoding['color']:
69+
elif chart.encoding.get('color'):
7070
groups.append('color')
7171

7272
list_fields = lambda c, g: [chart.encoding[i].field for i in g]

mplaltair/_axis.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -229,18 +229,3 @@ def convert_axis(ax, chart):
229229
_set_tick_locator(channel, ax)
230230
_set_tick_formatter(channel, ax)
231231
_set_label_angle(channel, ax)
232-
233-
# for channel in chart.to_dict()['encoding']:
234-
# if channel in ['x', 'y']:
235-
# chart_info = {'ax': ax, 'axis': channel,
236-
# 'data': _locate_channel_data(chart, channel),
237-
# 'dtype': _locate_channel_dtype(chart, channel),
238-
# 'mark': chart.mark}
239-
#
240-
# scale_info = _locate_channel_scale(chart, channel)
241-
# axis_info = _locate_channel_axis(chart, channel)
242-
#
243-
# # _set_limits(chart_info, scale_info)
244-
# _set_tick_locator(chart_info, axis_info)
245-
# _set_tick_formatter(chart_info, axis_info)
246-
# _set_label_angle(chart_info, axis_info)

mplaltair/parse_chart.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,12 @@ class ChartMetadata(object):
8585
"""
8686

8787
def __init__(self, alt_chart):
88+
89+
if not alt_chart.to_dict()['mark']:
90+
raise ValueError("Mark not provided")
91+
if not alt_chart.to_dict().get('encoding'):
92+
raise ValueError("Ranged encoding channels like x2, y2 not allowed for Mark: {}".format(alt_chart.mark))
93+
8894
_normalize_data(alt_chart)
8995
self.data = alt_chart.data
9096
self.mark = alt_chart.mark

mplaltair/tests/test_axis.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import matplotlib.pyplot as plt
44
import pandas as pd
55
from mplaltair import convert
6+
from .._axis import convert_axis
7+
from parse_chart import ChartMetadata
68
import pytest
79

810
df_quant = pd.DataFrame({
@@ -50,7 +52,8 @@ def test_axis_set_tick_formatter_fail():
5052
This test is just for temporary coverage purposes."""
5153
from .._axis import _set_tick_formatter
5254
_, ax = plt.subplots()
53-
_set_tick_formatter({'ax': ax, 'dtype': 'ordinal'}, {})
55+
chart = ChartMetadata(alt.Chart(df_quant).mark_point().encode('a:N', 'c:O'))
56+
_set_tick_formatter(chart.encoding['x'], ax)
5457

5558

5659
@pytest.mark.mpl_image_compare(baseline_dir='baseline_images/test_axis')

mplaltair/tests/test_convert.py

Lines changed: 41 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import matplotlib.pyplot as plt
77
from mplaltair import convert
88
from mplaltair._convert import _convert
9+
from mplaltair.parse_chart import ChartMetadata
910

1011

1112
df = pd.DataFrame({
@@ -22,170 +23,175 @@
2223
})
2324

2425

25-
def test_encoding_not_provided():
26+
def test_encoding_not_provided(): # TODO: move to the parse_chart tests
2627
chart_spec = alt.Chart(df).mark_point()
2728
with pytest.raises(ValueError):
28-
_convert(chart_spec)
29+
chart = ChartMetadata(chart_spec)
30+
# _convert(chart)
2931

3032
def test_invalid_encodings():
3133
chart_spec = alt.Chart(df).encode(x2='quant').mark_point()
34+
chart = ChartMetadata(chart_spec)
3235
with pytest.raises(ValueError):
33-
_convert(chart_spec)
36+
_convert(chart)
3437

3538
@pytest.mark.xfail(raises=TypeError)
36-
def test_invalid_temporal():
39+
def test_invalid_temporal(): # TODO: move to parse_chart tests???
3740
chart = alt.Chart(df).mark_point().encode(alt.X('quant:T'))
38-
_convert(chart)
41+
ChartMetadata(chart)
42+
# _convert(chart)
3943

4044
@pytest.mark.parametrize('channel', ['quant', 'ord', 'nom'])
4145
def test_convert_x_success(channel):
4246
chart_spec = alt.Chart(df).encode(x=channel).mark_point()
43-
mapping = _convert(chart_spec)
47+
chart = ChartMetadata(chart_spec)
48+
mapping = _convert(chart)
4449
assert list(mapping['x']) == list(df[channel].values)
4550

4651
@pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"])
4752
def test_convert_x_success_temporal(column):
4853
chart = alt.Chart(df).mark_point().encode(alt.X(column))
54+
chart = ChartMetadata(chart)
4955
mapping = _convert(chart)
5056
assert list(mapping['x']) == list(mdates.date2num(df[column].values))
5157

5258
def test_convert_x_fail():
53-
chart_spec = alt.Chart(df).encode(x='b:N').mark_point()
5459
with pytest.raises(KeyError):
60+
chart_spec = ChartMetadata(alt.Chart(df).encode(x='b:N').mark_point())
5561
_convert(chart_spec)
5662

5763
@pytest.mark.parametrize('channel', ['quant', 'ord', 'nom'])
5864
def test_convert_y_success(channel):
59-
chart_spec = alt.Chart(df).encode(y=channel).mark_point()
65+
chart_spec = ChartMetadata(alt.Chart(df).encode(y=channel).mark_point())
6066
mapping = _convert(chart_spec)
6167
assert list(mapping['y']) == list(df[channel].values)
6268

6369
@pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"])
6470
def test_convert_y_success_temporal(column):
65-
chart = alt.Chart(df).mark_point().encode(alt.Y(column))
71+
chart = ChartMetadata(alt.Chart(df).mark_point().encode(alt.Y(column)))
6672
mapping = _convert(chart)
6773
assert list(mapping['y']) == list(mdates.date2num(df[column].values))
6874

6975
def test_convert_y_fail():
70-
chart_spec = alt.Chart(df).encode(y='b:N').mark_point()
7176
with pytest.raises(KeyError):
77+
chart_spec = ChartMetadata(alt.Chart(df).encode(y='b:N').mark_point())
7278
_convert(chart_spec)
7379

7480
@pytest.mark.xfail(raises=ValueError, reason="It doesn't make sense to have x2 and y2 on scatter plots")
7581
def test_quantitative_x2_y2():
76-
chart = alt.Chart(df_quant).mark_point().encode(alt.X('a'), alt.Y('b'), alt.X2('c'), alt.Y2('alpha'))
82+
chart = ChartMetadata(alt.Chart(df_quant).mark_point().encode(alt.X('a'), alt.Y('b'), alt.X2('c'), alt.Y2('alpha')))
7783
_convert(chart)
7884

7985
@pytest.mark.xfail(raises=ValueError)
8086
@pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"])
8187
def test_convert_x2_y2_fail_temporal(column):
82-
chart = alt.Chart(df).mark_point().encode(alt.X2(column), alt.Y2(column))
88+
chart = ChartMetadata(alt.Chart(df).mark_point().encode(alt.X2(column), alt.Y2(column)))
8389
_convert(chart)
8490

8591
@pytest.mark.parametrize('channel,dtype', [('quant','quantitative'), ('ord','ordinal')])
8692
def test_convert_color_success(channel, dtype):
87-
chart_spec = alt.Chart(df).encode(color=alt.Color(field=channel, type=dtype)).mark_point()
93+
chart_spec = ChartMetadata(alt.Chart(df).encode(color=alt.Color(field=channel, type=dtype)).mark_point())
8894
mapping = _convert(chart_spec)
8995
assert list(mapping['c']) == list(df[channel].values)
9096

9197
def test_convert_color_success_nominal():
92-
chart_spec = alt.Chart(df).encode(color='nom').mark_point()
98+
chart_spec = ChartMetadata(alt.Chart(df).encode(color='nom').mark_point())
9399
with pytest.raises(NotImplementedError):
94100
_convert(chart_spec)
95101

96102
@pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"])
97103
def test_convert_color_success_temporal(column):
98-
chart = alt.Chart(df).mark_point().encode(alt.Color(column))
104+
chart = ChartMetadata(alt.Chart(df).mark_point().encode(alt.Color(column)))
99105
mapping = _convert(chart)
100106
assert list(mapping['c']) == list(mdates.date2num(df[column].values))
101107

102-
def test_convert_color_fail():
103-
chart_spec = alt.Chart(df).encode(color='b:N').mark_point()
108+
def test_convert_color_fail(): # TODO: What is this covering?
104109
with pytest.raises(KeyError):
110+
chart_spec = ChartMetadata(alt.Chart(df).encode(color='b:N').mark_point())
105111
_convert(chart_spec)
106112

107113
@pytest.mark.parametrize('channel,type', [('quant', 'Q'), ('ord', 'O')])
108114
def test_convert_fill(channel, type):
109-
chart_spec = alt.Chart(df).encode(fill='{}:{}'.format(channel, type)).mark_point()
115+
chart_spec = ChartMetadata(alt.Chart(df).encode(fill='{}:{}'.format(channel, type)).mark_point())
110116
mapping = _convert(chart_spec)
111117
assert list(mapping['c']) == list(df[channel].values)
112118

113119
def test_convert_fill_success_nominal():
114-
chart_spec = alt.Chart(df).encode(fill='nom').mark_point()
120+
chart_spec = ChartMetadata(alt.Chart(df).encode(fill='nom').mark_point())
115121
with pytest.raises(NotImplementedError):
116122
_convert(chart_spec)
117123

118124
@pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"])
119125
def test_convert_fill_success_temporal(column):
120-
chart = alt.Chart(df).mark_point().encode(alt.Fill(column))
126+
chart = ChartMetadata(alt.Chart(df).mark_point().encode(alt.Fill(column)))
121127
mapping = _convert(chart)
122128
assert list(mapping['c']) == list(mdates.date2num(df[column].values))
123129

124130

125-
def test_convert_fill_fail():
126-
chart_spec = alt.Chart(df).encode(fill='b:N').mark_point()
131+
def test_convert_fill_fail(): # TODO: what is this covering?
127132
with pytest.raises(KeyError):
133+
chart_spec = ChartMetadata(alt.Chart(df).encode(fill='b:N').mark_point())
128134
_convert(chart_spec)
129135

130136
@pytest.mark.xfail(raises=NotImplementedError, reason="The marker argument in scatter() cannot take arrays")
131137
def test_quantitative_shape():
132-
chart = alt.Chart(df_quant).mark_point().encode(alt.Shape('shape'))
138+
chart = ChartMetadata(alt.Chart(df_quant).mark_point().encode(alt.Shape('shape')))
133139
mapping = _convert(chart)
134140

135141
@pytest.mark.xfail(raises=NotImplementedError, reason="The marker argument in scatter() cannot take arrays")
136142
@pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"])
137143
def test_convert_shape_fail_temporal(column):
138-
chart = alt.Chart(df).mark_point().encode(alt.Shape(column))
144+
chart = ChartMetadata(alt.Chart(df).mark_point().encode(alt.Shape(column)))
139145
mapping = _convert(chart)
140146

141147
@pytest.mark.xfail(raises=NotImplementedError, reason="Merge: the dtype for opacity isn't assumed to be quantitative")
142148
def test_quantitative_opacity_value():
143-
chart = alt.Chart(df_quant).mark_point().encode(opacity=alt.value(.5))
149+
chart = ChartMetadata(alt.Chart(df_quant).mark_point().encode(opacity=alt.value(.5)))
144150
mapping = _convert(chart)
145151

146152
@pytest.mark.xfail(raises=NotImplementedError, reason="The alpha argument in scatter() cannot take arrays")
147153
def test_quantitative_opacity_array():
148-
chart = alt.Chart(df_quant).mark_point().encode(alt.Opacity('alpha'))
154+
chart = ChartMetadata(alt.Chart(df_quant).mark_point().encode(alt.Opacity('alpha')))
149155
_convert(chart)
150156

151157
@pytest.mark.xfail(raises=NotImplementedError, reason="The alpha argument in scatter() cannot take arrays")
152158
@pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"])
153159
def test_convert_opacity_fail_temporal(column):
154-
chart = alt.Chart(df).mark_point().encode(alt.Opacity(column))
160+
chart = ChartMetadata(alt.Chart(df).mark_point().encode(alt.Opacity(column)))
155161
_convert(chart)
156162

157163
@pytest.mark.parametrize('channel,type', [('quant', 'Q'), ('ord', 'O')])
158164
def test_convert_size_success(channel, type):
159-
chart_spec = alt.Chart(df).encode(size='{}:{}'.format(channel, type)).mark_point()
165+
chart_spec = ChartMetadata(alt.Chart(df).encode(size='{}:{}'.format(channel, type)).mark_point())
160166
mapping = _convert(chart_spec)
161167
assert list(mapping['s']) == list(df[channel].values)
162168

163169
def test_convert_size_success_nominal():
164-
chart_spec = alt.Chart(df).encode(size='nom').mark_point()
165170
with pytest.raises(NotImplementedError):
171+
chart_spec = ChartMetadata(alt.Chart(df).encode(size='nom').mark_point())
166172
_convert(chart_spec)
167173

168174
def test_convert_size_fail():
169-
chart_spec = alt.Chart(df).encode(size='b:N').mark_point()
170175
with pytest.raises(KeyError):
176+
chart_spec = ChartMetadata(alt.Chart(df).encode(size='b:N').mark_point())
171177
_convert(chart_spec)
172178

173179
@pytest.mark.xfail(raises=NotImplementedError, reason="Dates would need to be normalized for the size.")
174180
@pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"])
175181
def test_convert_size_fail_temporal(column):
176-
chart = alt.Chart(df).mark_point().encode(alt.Size(column))
182+
chart = ChartMetadata(alt.Chart(df).mark_point().encode(alt.Size(column)))
177183
_convert(chart)
178184

179185

180186
@pytest.mark.xfail(raises=NotImplementedError, reason="Stroke is not well supported in Altair")
181187
def test_quantitative_stroke():
182-
chart = alt.Chart(df_quant).mark_point().encode(alt.Stroke('fill'))
188+
chart = ChartMetadata(alt.Chart(df_quant).mark_point().encode(alt.Stroke('fill')))
183189
_convert(chart)
184190

185191
@pytest.mark.xfail(raises=NotImplementedError, reason="Stroke is not well defined in Altair")
186192
@pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"])
187193
def test_convert_stroke_fail_temporal(column):
188-
chart = alt.Chart(df).mark_point().encode(alt.Stroke(column))
194+
chart = ChartMetadata(alt.Chart(df).mark_point().encode(alt.Stroke(column)))
189195
_convert(chart)
190196

191197

@@ -194,12 +200,12 @@ def test_convert_stroke_fail_temporal(column):
194200
@pytest.mark.xfail(raises=NotImplementedError, reason="Aggregate functions are not supported yet")
195201
def test_quantitative_x_count_y():
196202
df_count = pd.DataFrame({"a": [1, 1, 2, 3, 5], "b": [1.4, 1.4, 2.9, 3.18, 5.3]})
197-
chart = alt.Chart(df_count).mark_point().encode(alt.X('a'), alt.Y('count()'))
203+
chart = ChartMetadata(alt.Chart(df_count).mark_point().encode(alt.X('a'), alt.Y('count()')))
198204
mapping = _convert(chart)
199205

200206
@pytest.mark.xfail(raises=NotImplementedError, reason="specifying timeUnit is not supported yet")
201207
def test_timeUnit():
202-
chart = alt.Chart(df).mark_point().encode(alt.X('date(combination)'))
208+
chart = ChartMetadata(alt.Chart(df).mark_point().encode(alt.X('date(combination)')))
203209
_convert(chart)
204210

205211
# Plots

mplaltair/tests/test_data.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,6 @@ def test_data_field_quantitative(column, dtype):
3636
chart = alt.Chart(df).mark_point().encode(alt.X(field=column, type=dtype))
3737
for channel in chart.to_dict()['encoding']:
3838
data = _data._locate_channel_data(chart, channel)
39-
if dtype == 'temporal':
40-
assert list(data) == list(_data._convert_to_mpl_date(df[column].values))
41-
else:
4239
assert list(data) == list(df[column].values)
4340

4441

@@ -54,7 +51,7 @@ def test_data_shorthand_temporal():
5451
chart = alt.Chart(df).mark_point().encode(alt.X('combination'))
5552
for channel in chart.to_dict()['encoding']:
5653
data = _data._locate_channel_data(chart, channel)
57-
assert list(data) == list(_data._convert_to_mpl_date(df['combination'].values))
54+
assert list(data) == list(df['combination'].values)
5855

5956

6057
def test_data_value_quantitative():

0 commit comments

Comments
 (0)