Skip to content

Commit b43d395

Browse files
Add return type dataclassreader (#57)
* Update dataclass_reader.py Make DataclassReader a Generic class whose type is defined by the type of dataclass supplied to the init function * Fix formatting * Make a generic class --------- Co-authored-by: Daniel Furtado <jazzmachine77@gmail.com>
1 parent 5501a8f commit b43d395

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

dataclass_csv/dataclass_reader.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import csv
33

44
from datetime import date, datetime
5-
from typing import Union, Type, Optional, Sequence, Dict, Any
5+
from typing import Union, Type, Optional, Sequence, Dict, Any, List, Generic, TypeVar
66

77
import typing
88

@@ -11,6 +11,8 @@
1111

1212
from collections import Counter
1313

14+
T = TypeVar("T")
15+
1416
def strtobool(value: str) -> bool:
1517
trueValues = ["true", "yes", "t", "y", "on", "1"]
1618

@@ -54,11 +56,11 @@ def get_args(t):
5456
return tuple()
5557

5658

57-
class DataclassReader:
59+
class DataclassReader(Generic[T]):
5860
def __init__(
5961
self,
6062
f: Any,
61-
cls: Type[object],
63+
cls: Type[T],
6264
fieldnames: Optional[Sequence[str]] = None,
6365
restkey: Optional[str] = None,
6466
restval: Optional[Any] = None,
@@ -192,7 +194,7 @@ def _parse_date_value(self, field, date_value, field_type):
192194
else:
193195
return datetime_obj
194196

195-
def _process_row(self, row):
197+
def _process_row(self, row) -> T:
196198
values = dict()
197199

198200
for field in dataclasses.fields(self._cls):
@@ -251,7 +253,7 @@ def _process_row(self, row):
251253
values[field.name] = transformed_value
252254
return self._cls(**values)
253255

254-
def __next__(self):
256+
def __next__(self) -> T:
255257
row = next(self._reader)
256258
return self._process_row(row)
257259

dataclass_csv/dataclass_reader.pyi

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
11
from .field_mapper import FieldMapper as FieldMapper
2-
from typing import Any, Optional, Sequence, Type
2+
from typing import Any, Optional, Sequence, Type, Generic, TypeVar
33

4-
class DataclassReader:
4+
T = TypeVar("T")
5+
6+
7+
class DataclassReader(Generic[T]):
58
def __init__(
69
self,
710
f: Any,
8-
cls: Type[object],
11+
cls: Type[T],
912
fieldnames: Optional[Sequence[str]] = ...,
1013
restkey: Optional[str] = ...,
1114
restval: Optional[Any] = ...,
1215
dialect: str = ...,
1316
*args: Any,
1417
**kwds: Any
1518
) -> None: ...
16-
def __next__(self) -> None: ...
19+
def __next__(self) -> T: ...
1720
def __iter__(self) -> Any: ...
1821
def map(self, csv_fieldname: str) -> FieldMapper: ...

0 commit comments

Comments
 (0)