1
1
import copy
2
+ import json
2
3
import logging
3
4
from collections import deque
4
5
from collections .abc import Generator
@@ -153,8 +154,8 @@ def reset_copy(self):
153
154
154
155
return new_instance
155
156
156
- def dump_state (self ):
157
- return {name : param .dump_state () for name , param in self .named_parameters ()}
157
+ def dump_state (self , json_mode = True ):
158
+ return {name : param .dump_state (json_mode = json_mode ) for name , param in self .named_parameters ()}
158
159
159
160
def load_state (self , state ):
160
161
for name , param in self .named_parameters ():
@@ -220,9 +221,9 @@ def save(self, path, save_program=False, modules_to_serialize=None):
220
221
221
222
return
222
223
223
- state = self .dump_state ()
224
- state ["metadata" ] = metadata
225
224
if path .suffix == ".json" :
225
+ state = self .dump_state ()
226
+ state ["metadata" ] = metadata
226
227
try :
227
228
with open (path , "w" , encoding = "utf-8" ) as f :
228
229
f .write (orjson .dumps (state , option = orjson .OPT_INDENT_2 | orjson .OPT_APPEND_NEWLINE ).decode ("utf-8" ))
@@ -233,6 +234,8 @@ def save(self, path, save_program=False, modules_to_serialize=None):
233
234
"with `.pkl`, or saving the whole program by setting `save_program=True`."
234
235
)
235
236
elif path .suffix == ".pkl" :
237
+ state = self .dump_state (json_mode = False )
238
+ state ["metadata" ] = metadata
236
239
with open (path , "wb" ) as f :
237
240
cloudpickle .dump (state , f )
238
241
else :
0 commit comments