Skip to content

Commit 8e548e6

Browse files
kdorrpalnabarun
authored andcommitted
Integrate _normalize_data with existing infrastructure
1 parent b860a02 commit 8e548e6

File tree

3 files changed

+16
-37
lines changed

3 files changed

+16
-37
lines changed

mplaltair/_convert.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import matplotlib.dates as mdates
2-
from ._data import _locate_channel_data, _locate_channel_dtype
2+
from ._data import _locate_channel_data, _locate_channel_dtype, _normalize_data
33

44
def _allowed_ranged_marks(enc_channel, mark):
55
"""TODO: DOCS
@@ -109,6 +109,8 @@ def _convert(chart):
109109
"""
110110
mapping = {}
111111

112+
_normalize_data(chart)
113+
112114
if not chart.to_dict().get('encoding'):
113115
raise ValueError("Encoding not provided with the chart specification")
114116

mplaltair/_data.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,17 @@
44

55
from ._utils import _fetch
66

7-
def _normalize_data(spec):
7+
def _normalize_data(chart):
88
"""Converts the data to a Pandas dataframe
99
1010
Parameters
1111
----------
12-
spec : dict
12+
chart : altair.Chart
1313
The vega-lite specification in json format
1414
1515
Returns
1616
-------
17-
dict
18-
The vega-lite specification with the data format fixed to a Pandas dataframe
17+
None
1918
2019
Raises
2120
------
@@ -26,20 +25,19 @@ def _normalize_data(spec):
2625
Raised when the data specification has an unsupported data source
2726
"""
2827

28+
spec = chart.to_dict()
29+
2930
if not spec.get('data'):
3031
raise ValidationError('Please specify a data source.')
3132

3233
if spec['data'].get('url'):
3334
df = pd.DataFrame(_fetch(spec['data']['url']))
3435
elif spec['data'].get('values'):
35-
df = pd.DataFrame(spec['data']['values'])
36+
return
3637
else:
3738
raise NotImplementedError('Given data specification is unsupported at the moment.')
3839

39-
del spec['data']
40-
spec['data'] = df
41-
42-
return spec
40+
chart.data = df
4341

4442
def _locate_channel_dtype(chart, channel):
4543
"""Locates dtype used for each channel

mplaltair/tests/test_data.py

Lines changed: 6 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,35 +17,14 @@
1717
})
1818

1919
def test_data_list():
20-
spec = {
21-
"data": {
22-
"values": [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
23-
24-
}
25-
}
26-
assert type(_normalize_data(spec)["data"]) == pd.DataFrame
20+
chart = alt.Chart(pd.DataFrame({'a': [1], 'b': [2], 'c': [3]})).mark_point()
21+
_normalize_data(chart)
22+
assert type(chart.data) == pd.DataFrame
2723

2824
def test_data_url():
29-
spec = {
30-
"data": {
31-
"url": data.cars.url
32-
}
33-
}
34-
assert type(_normalize_data(spec)["data"]) == pd.DataFrame
35-
36-
def test_data_no_pass():
37-
spec = {}
38-
with pytest.raises(ValidationError):
39-
_normalize_data(spec)
40-
41-
def test_data_invalid():
42-
spec = {
43-
"data": {
44-
"source": "path"
45-
}
46-
}
47-
with pytest.raises(NotImplementedError):
48-
_normalize_data(spec)
25+
chart = alt.Chart(data.cars.url).mark_point()
26+
_normalize_data(chart)
27+
assert type(chart.data) == pd.DataFrame
4928

5029
# _locate_channel_data() tests
5130

0 commit comments

Comments
 (0)