Skip to content

Commit c7ad9f4

Browse files
add json mode
1 parent 5f2ca85 commit c7ad9f4

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

dspy/predict/predict.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def reset(self):
3030
self.train = []
3131
self.demos = []
3232

33-
def dump_state(self):
33+
def dump_state(self, json_mode=True):
3434
state_keys = ["traces", "train"]
3535
state = {k: getattr(self, k) for k in state_keys}
3636

@@ -42,7 +42,7 @@ def dump_state(self):
4242
# FIXME: Saving BaseModels as strings in examples doesn't matter because you never re-access as an object
4343
demo[field] = serialize_object(demo[field])
4444

45-
if isinstance(demo, dict):
45+
if isinstance(demo, dict) or not json_mode:
4646
state["demos"].append(demo)
4747
else:
4848
state["demos"].append(demo.toDict())

dspy/primitives/base_module.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import copy
2+
import json
23
import logging
34
from collections import deque
45
from collections.abc import Generator
@@ -153,8 +154,8 @@ def reset_copy(self):
153154

154155
return new_instance
155156

156-
def dump_state(self):
157-
return {name: param.dump_state() for name, param in self.named_parameters()}
157+
def dump_state(self, json_mode=True):
158+
return {name: param.dump_state(json_mode=json_mode) for name, param in self.named_parameters()}
158159

159160
def load_state(self, state):
160161
for name, param in self.named_parameters():
@@ -220,9 +221,9 @@ def save(self, path, save_program=False, modules_to_serialize=None):
220221

221222
return
222223

223-
state = self.dump_state()
224-
state["metadata"] = metadata
225224
if path.suffix == ".json":
225+
state = self.dump_state()
226+
state["metadata"] = metadata
226227
try:
227228
with open(path, "w", encoding="utf-8") as f:
228229
f.write(orjson.dumps(state, option=orjson.OPT_INDENT_2 | orjson.OPT_APPEND_NEWLINE).decode("utf-8"))
@@ -233,6 +234,8 @@ def save(self, path, save_program=False, modules_to_serialize=None):
233234
"with `.pkl`, or saving the whole program by setting `save_program=True`."
234235
)
235236
elif path.suffix == ".pkl":
237+
state = self.dump_state(json_mode=False)
238+
state["metadata"] = metadata
236239
with open(path, "wb") as f:
237240
cloudpickle.dump(state, f)
238241
else:

0 commit comments

Comments
 (0)