Skip to content

Commit da482fd

Browse files
Replace ujson by orjson (#8655)
* port to orjson from ujson
* init
* some fixes
* fix demo serialization
* add json mode
* remove decoding
* fix nested example
* remove redundant decoding

---------

Co-authored-by: Aadya Chinubhai <aadyachinubhai@gmail.com>
1 parent 5d8b080 commit da482fd

File tree

14 files changed

+145
-172
lines changed

14 files changed

+145
-172
lines changed

dspy/clients/cache.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,8 +7,8 @@
77
from typing import Any
88

99
import cloudpickle
10+
import orjson
1011
import pydantic
11-
import ujson
1212
from cachetools import LRUCache
1313
from diskcache import FanoutCache
1414

@@ -93,7 +93,7 @@ def transform_value(value):
9393
return value
9494

9595
params = {k: transform_value(v) for k, v in request.items() if k not in ignored_args_for_cache_key}
96-
return sha256(ujson.dumps(params, sort_keys=True).encode()).hexdigest()
96+
return sha256(orjson.dumps(params, option=orjson.OPT_SORT_KEYS)).hexdigest()
9797

9898
def get(self, request: dict[str, Any], ignored_args_for_cache_key: list[str] | None = None) -> Any:
9999
try:

dspy/clients/databricks.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -4,8 +4,8 @@
44
import time
55
from typing import TYPE_CHECKING, Any
66

7+
import orjson
78
import requests
8-
import ujson
99

1010
from dspy.clients.provider import Provider, TrainingJob
1111
from dspy.clients.utils_finetune import TrainDataFormat, get_finetune_directory
@@ -265,8 +265,7 @@ def _get_workspace_client() -> "WorkspaceClient":
265265
from databricks.sdk import WorkspaceClient
266266
except ImportError:
267267
raise ImportError(
268-
"To use Databricks finetuning, please install the databricks-sdk package via "
269-
"`pip install databricks-sdk`."
268+
"To use Databricks finetuning, please install the databricks-sdk package via `pip install databricks-sdk`."
270269
)
271270
return WorkspaceClient()
272271

@@ -311,14 +310,14 @@ def _save_data_to_local_file(train_data: list[dict[str, Any]], data_format: Trai
311310
finetune_dir = get_finetune_directory()
312311
file_path = os.path.join(finetune_dir, file_name)
313312
file_path = os.path.abspath(file_path)
314-
with open(file_path, "w") as f:
313+
with open(file_path, "wb") as f:
315314
for item in train_data:
316315
if data_format == TrainDataFormat.CHAT:
317316
_validate_chat_data(item)
318317
elif data_format == TrainDataFormat.COMPLETION:
319318
_validate_completion_data(item)
320319

321-
f.write(ujson.dumps(item) + "\n")
320+
f.write(orjson.dumps(item) + b"\n")
322321
return file_path
323322

324323

dspy/clients/utils_finetune.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
from enum import Enum
33
from typing import Any, Literal, TypedDict
44

5-
import ujson
5+
import orjson
66

77
import dspy
88
from dspy.adapters.base import Adapter
@@ -58,9 +58,9 @@ def get_finetune_directory() -> str:
5858

5959

6060
def write_lines(file_path, data):
61-
with open(file_path, "w") as f:
61+
with open(file_path, "wb") as f:
6262
for item in data:
63-
f.write(ujson.dumps(item) + "\n")
63+
f.write(orjson.dumps(item) + b"\n")
6464

6565

6666
def save_data(
@@ -75,9 +75,9 @@ def save_data(
7575
finetune_dir = get_finetune_directory()
7676
file_path = os.path.join(finetune_dir, file_name)
7777
file_path = os.path.abspath(file_path)
78-
with open(file_path, "w") as f:
78+
with open(file_path, "wb") as f:
7979
for item in data:
80-
f.write(ujson.dumps(item) + "\n")
80+
f.write(orjson.dumps(item) + b"\n")
8181
return file_path
8282

8383

dspy/predict/predict.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ def reset(self):
3030
self.train = []
3131
self.demos = []
3232

33-
def dump_state(self):
33+
def dump_state(self, json_mode=True):
3434
state_keys = ["traces", "train"]
3535
state = {k: getattr(self, k) for k in state_keys}
3636

@@ -42,7 +42,10 @@ def dump_state(self):
4242
# FIXME: Saving BaseModels as strings in examples doesn't matter because you never re-access as an object
4343
demo[field] = serialize_object(demo[field])
4444

45-
state["demos"].append(demo)
45+
if isinstance(demo, dict) or not json_mode:
46+
state["demos"].append(demo)
47+
else:
48+
state["demos"].append(demo.toDict())
4649

4750
state["signature"] = self.signature.dump_state()
4851
state["lm"] = self.lm.dump_state() if self.lm else None

dspy/predict/refine.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
import textwrap
33
from typing import Callable
44

5-
import ujson
5+
import orjson
66

77
import dspy
88
from dspy.adapters.utils import get_field_description_string
@@ -158,10 +158,9 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs):
158158
}
159159

160160
advise_kwargs = dict(**modules, **trajectory, **reward, module_names=module_names)
161-
# advise_kwargs = {k: ujson.dumps(recursive_mask(v), indent=2) for k, v in advise_kwargs.items()}
162161
# only dumps if it's a list or dict
163162
advise_kwargs = {
164-
k: v if isinstance(v, str) else ujson.dumps(recursive_mask(v), indent=2)
163+
k: v if isinstance(v, str) else orjson.dumps(recursive_mask(v), option=orjson.OPT_INDENT_2).decode()
165164
for k, v in advise_kwargs.items()
166165
}
167166
advice = dspy.Predict(OfferFeedback)(**advise_kwargs).advice
@@ -200,7 +199,7 @@ def inspect_modules(program):
200199
def recursive_mask(o):
201200
# If the object is already serializable, return it.
202201
try:
203-
ujson.dumps(o)
202+
orjson.dumps(o)
204203
return o
205204
except TypeError:
206205
pass

dspy/primitives/base_module.py

Lines changed: 17 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
from pathlib import Path
66

77
import cloudpickle
8-
import ujson
8+
import orjson
99

1010
from dspy.utils.saving import get_dependency_versions
1111

@@ -153,8 +153,8 @@ def reset_copy(self):
153153

154154
return new_instance
155155

156-
def dump_state(self):
157-
return {name: param.dump_state() for name, param in self.named_parameters()}
156+
def dump_state(self, json_mode=True):
157+
return {name: param.dump_state(json_mode=json_mode) for name, param in self.named_parameters()}
158158

159159
def load_state(self, state):
160160
for name, param in self.named_parameters():
@@ -169,10 +169,10 @@ def save(self, path, save_program=False, modules_to_serialize=None):
169169
- `save_program=True`: Save the whole module to a directory via cloudpickle, which contains both the state and
170170
architecture of the model.
171171
172-
If `save_program=True` and `modules_to_serialize` are provided, it will register those modules for serialization
173-
with cloudpickle's `register_pickle_by_value`. This causes cloudpickle to serialize the module by value rather
174-
than by reference, ensuring the module is fully preserved along with the saved program. This is useful
175-
when you have custom modules that need to be serialized alongside your program. If None, then no modules
172+
If `save_program=True` and `modules_to_serialize` are provided, it will register those modules for serialization
173+
with cloudpickle's `register_pickle_by_value`. This causes cloudpickle to serialize the module by value rather
174+
than by reference, ensuring the module is fully preserved along with the saved program. This is useful
175+
when you have custom modules that need to be serialized alongside your program. If None, then no modules
176176
will be registered for serialization.
177177
178178
We also save the dependency versions, so that the loaded model can check if there is a version mismatch on
@@ -215,24 +215,26 @@ def save(self, path, save_program=False, modules_to_serialize=None):
215215
f"Saving failed with error: {e}. Please remove the non-picklable attributes from your DSPy program, "
216216
"or consider using state-only saving by setting `save_program=False`."
217217
)
218-
with open(path / "metadata.json", "w", encoding="utf-8") as f:
219-
ujson.dump(metadata, f, indent=2, ensure_ascii=False)
218+
with open(path / "metadata.json", "wb") as f:
219+
f.write(orjson.dumps(metadata, option=orjson.OPT_INDENT_2 | orjson.OPT_APPEND_NEWLINE))
220220

221221
return
222222

223-
state = self.dump_state()
224-
state["metadata"] = metadata
225223
if path.suffix == ".json":
224+
state = self.dump_state()
225+
state["metadata"] = metadata
226226
try:
227-
with open(path, "w", encoding="utf-8") as f:
228-
f.write(ujson.dumps(state, indent=2 , ensure_ascii=False))
227+
with open(path, "wb") as f:
228+
f.write(orjson.dumps(state, option=orjson.OPT_INDENT_2 | orjson.OPT_APPEND_NEWLINE))
229229
except Exception as e:
230230
raise RuntimeError(
231231
f"Failed to save state to {path} with error: {e}. Your DSPy program may contain non "
232232
"json-serializable objects, please consider saving the state in .pkl by using `path` ending "
233233
"with `.pkl`, or saving the whole program by setting `save_program=True`."
234234
)
235235
elif path.suffix == ".pkl":
236+
state = self.dump_state(json_mode=False)
237+
state["metadata"] = metadata
236238
with open(path, "wb") as f:
237239
cloudpickle.dump(state, f)
238240
else:
@@ -248,8 +250,8 @@ def load(self, path):
248250
path = Path(path)
249251

250252
if path.suffix == ".json":
251-
with open(path, encoding="utf-8") as f:
252-
state = ujson.loads(f.read())
253+
with open(path, "rb") as f:
254+
state = orjson.loads(f.read())
253255
elif path.suffix == ".pkl":
254256
with open(path, "rb") as f:
255257
state = cloudpickle.load(f)

dspy/primitives/example.py

Lines changed: 15 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -105,4 +105,18 @@ def without(self, *keys):
105105
return copied
106106

107107
def toDict(self): # noqa: N802
108-
return self._store.copy()
108+
def convert_to_serializable(value):
109+
if hasattr(value, "toDict"):
110+
return value.toDict()
111+
elif isinstance(value, list):
112+
return [convert_to_serializable(item) for item in value]
113+
elif isinstance(value, dict):
114+
return {k: convert_to_serializable(v) for k, v in value.items()}
115+
else:
116+
return value
117+
118+
serializable_store = {}
119+
for k, v in self._store.items():
120+
serializable_store[k] = convert_to_serializable(v)
121+
122+
return serializable_store

dspy/streaming/streamify.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator
88

99
import litellm
10-
import ujson
10+
import orjson
1111
from anyio import create_memory_object_stream, create_task_group
1212
from anyio.streams.memory import MemoryObjectSendStream
1313
from litellm import ModelResponseStream
@@ -266,10 +266,10 @@ async def streaming_response(streamer: AsyncGenerator) -> AsyncGenerator:
266266
async for value in streamer:
267267
if isinstance(value, Prediction):
268268
data = {"prediction": dict(value.items(include_dspy=False))}
269-
yield f"data: {ujson.dumps(data)}\n\n"
269+
yield f"data: {orjson.dumps(data).decode()}\n\n"
270270
elif isinstance(value, litellm.ModelResponseStream):
271271
data = {"chunk": value.json()}
272-
yield f"data: {ujson.dumps(data)}\n\n"
272+
yield f"data: {orjson.dumps(data).decode()}\n\n"
273273
elif isinstance(value, str) and value.startswith("data:"):
274274
# The chunk value is an OpenAI-compatible streaming chunk value,
275275
# e.g. "data: {"finish_reason": "stop", "index": 0, "is_finished": True, ...}",

dspy/teleprompt/simba_utils.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
import textwrap
44
from typing import Callable
55

6-
import ujson
6+
import orjson
77

88
import dspy
99
from dspy.adapters.utils import get_field_description_string
@@ -120,7 +120,7 @@ def append_a_rule(bucket, system, **kwargs):
120120
"module_names": module_names,
121121
}
122122

123-
kwargs = {k: v if isinstance(v, str) else ujson.dumps(recursive_mask(v), indent=2)
123+
kwargs = {k: v if isinstance(v, str) else orjson.dumps(recursive_mask(v), option=orjson.OPT_INDENT_2).decode()
124124
for k, v in kwargs.items()}
125125
advice = dspy.Predict(OfferFeedback)(**kwargs).module_advice
126126

@@ -194,9 +194,9 @@ def inspect_modules(program):
194194
def recursive_mask(o):
195195
# If the object is already serializable, return it.
196196
try:
197-
ujson.dumps(o)
197+
orjson.dumps(o)
198198
return o
199-
except TypeError:
199+
except (TypeError, orjson.JSONEncodeError):
200200
pass
201201

202202
# If it's a dictionary, apply recursively to its values.

dspy/utils/saving.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
from typing import TYPE_CHECKING
55

66
import cloudpickle
7-
import ujson
7+
import orjson
88

99
if TYPE_CHECKING:
1010
from dspy.primitives.module import Module
@@ -40,7 +40,7 @@ def load(path: str) -> "Module":
4040
raise FileNotFoundError(f"The path '{path}' does not exist.")
4141

4242
with open(path / "metadata.json") as f:
43-
metadata = ujson.load(f)
43+
metadata = orjson.loads(f.read())
4444

4545
dependency_versions = get_dependency_versions()
4646
saved_dependency_versions = metadata["dependency_versions"]

0 commit comments

Comments
 (0)