Skip to content

Commit fec2ad3

Browse files
PR review fixes
1 parent 28840bf commit fec2ad3

File tree

2 files changed

+42
-41
lines changed

2 files changed

+42
-41
lines changed

pyspark_datasources/jsonplaceholder.py

Lines changed: 42 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
from typing import Dict, Any, List, Iterator
22
import requests
3+
import logging
34
from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
45
from pyspark.sql.types import StructType
56
from pyspark.sql import Row
@@ -76,26 +77,22 @@ def __init__(self, options=None):
7677
self.options = options or {}
7778

7879
def schema(self) -> str:
79-
endpoint = self.options.get("endpoint", "posts")
80+
""" Returns the schema for the selected endpoint."""
81+
schemas = {
82+
"posts": "userId INT, id INT, title STRING, body STRING",
83+
"users": ("id INT, name STRING, username STRING, email STRING, phone STRING, "
84+
"website STRING, address_street STRING, address_suite STRING, "
85+
"address_city STRING, address_zipcode STRING, address_geo_lat STRING, "
86+
"address_geo_lng STRING, company_name STRING, company_catchPhrase STRING, "
87+
"company_bs STRING"),
88+
"todos": "userId INT, id INT, title STRING, completed BOOLEAN",
89+
"comments": "postId INT, id INT, name STRING, email STRING, body STRING",
90+
"albums": "userId INT, id INT, title STRING",
91+
"photos": "albumId INT, id INT, title STRING, url STRING, thumbnailUrl STRING"
92+
}
8093

81-
if endpoint == "posts":
82-
return "userId INT, id INT, title STRING, body STRING"
83-
elif endpoint == "users":
84-
return ("id INT, name STRING, username STRING, email STRING, phone STRING, "
85-
"website STRING, address_street STRING, address_suite STRING, "
86-
"address_city STRING, address_zipcode STRING, address_geo_lat STRING, "
87-
"address_geo_lng STRING, company_name STRING, company_catchPhrase STRING, "
88-
"company_bs STRING")
89-
elif endpoint == "todos":
90-
return "userId INT, id INT, title STRING, completed BOOLEAN"
91-
elif endpoint == "comments":
92-
return "postId INT, id INT, name STRING, email STRING, body STRING"
93-
elif endpoint == "albums":
94-
return "userId INT, id INT, title STRING"
95-
elif endpoint == "photos":
96-
return "albumId INT, id INT, title STRING, url STRING, thumbnailUrl STRING"
97-
else:
98-
return "userId INT, id INT, title STRING, body STRING"
94+
endpoint = self.options.get("endpoint", "posts")
95+
return schemas.get(endpoint, schemas["posts"])
9996

10097
def reader(self, schema: StructType) -> DataSourceReader:
10198
return JSONPlaceholderReader(self.options)
@@ -136,28 +133,30 @@ def read(self, partition: InputPartition) -> Iterator[Row]:
136133
elif not isinstance(data, list):
137134
data = []
138135

139-
processed_data = []
140-
for item in data:
141-
processed_item = self._process_item(item)
142-
processed_data.append(processed_item)
143-
144-
return iter(processed_data)
136+
return iter([self._process_item(item) for item in data])
145137

146-
except Exception:
138+
except requests.RequestException as e:
139+
logging.warning(f"Failed to fetch data from {url}: {e}")
140+
return iter([])
141+
except ValueError as e:
142+
logging.warning(f"Failed to parse JSON from {url}: {e}")
143+
return iter([])
144+
except Exception as e:
145+
logging.error(f"Unexpected error while reading data: {e}")
147146
return iter([])
148147

149148
def _process_item(self, item: Dict[str, Any]) -> Row:
150149
"""Process individual items based on endpoint type"""
151150

152-
if self.endpoint == "posts":
151+
def _process_posts(item):
153152
return Row(
154153
userId=item.get("userId"),
155154
id=item.get("id"),
156155
title=item.get("title", ""),
157156
body=item.get("body", "")
158157
)
159158

160-
elif self.endpoint == "users":
159+
def _process_users(item):
161160
address = item.get("address", {})
162161
geo = address.get("geo", {})
163162
company = item.get("company", {})
@@ -180,15 +179,15 @@ def _process_item(self, item: Dict[str, Any]) -> Row:
180179
company_bs=company.get("bs", "")
181180
)
182181

183-
elif self.endpoint == "todos":
182+
def _process_todos(item):
184183
return Row(
185184
userId=item.get("userId"),
186185
id=item.get("id"),
187186
title=item.get("title", ""),
188187
completed=item.get("completed", False)
189188
)
190189

191-
elif self.endpoint == "comments":
190+
def _process_comments(item):
192191
return Row(
193192
postId=item.get("postId"),
194193
id=item.get("id"),
@@ -197,14 +196,14 @@ def _process_item(self, item: Dict[str, Any]) -> Row:
197196
body=item.get("body", "")
198197
)
199198

200-
elif self.endpoint == "albums":
199+
def _process_albums(item):
201200
return Row(
202201
userId=item.get("userId"),
203202
id=item.get("id"),
204203
title=item.get("title", "")
205204
)
206205

207-
elif self.endpoint == "photos":
206+
def _process_photos(item):
208207
return Row(
209208
albumId=item.get("albumId"),
210209
id=item.get("id"),
@@ -213,10 +212,14 @@ def _process_item(self, item: Dict[str, Any]) -> Row:
213212
thumbnailUrl=item.get("thumbnailUrl", "")
214213
)
215214

216-
else:
217-
return Row(
218-
userId=item.get("userId"),
219-
id=item.get("id"),
220-
title=item.get("title", ""),
221-
body=item.get("body", "")
222-
)
215+
processors = {
216+
"posts": _process_posts,
217+
"users": _process_users,
218+
"todos": _process_todos,
219+
"comments": _process_comments,
220+
"albums": _process_albums,
221+
"photos": _process_photos
222+
}
223+
224+
processor = processors.get(self.endpoint, _process_posts)
225+
return processor(item)

tests/test_data_sources.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -66,14 +66,12 @@ def test_opensky_datasource_stream(spark):
6666
assert result.count() > 0 # Verify we got some data
6767

6868
def test_jsonplaceholder_posts():
69-
from pyspark_datasources.jsonplaceholder import JSONPlaceholderDataSource
7069
spark.dataSource.register(JSONPlaceholderDataSource)
7170
posts_df = spark.read.format("jsonplaceholder").option("endpoint", "posts").load()
7271
assert posts_df.count() > 0 # Ensure we have some posts
7372

7473

7574
def test_jsonplaceholder_referential_integrity():
76-
from pyspark_datasources.jsonplaceholder import JSONPlaceholderDataSource
7775
spark.dataSource.register(JSONPlaceholderDataSource)
7876
users_df = spark.read.format("jsonplaceholder").option("endpoint", "users").load()
7977
posts_df = spark.read.format("jsonplaceholder").option("endpoint", "posts").load()

0 commit comments

Comments
 (0)