@@ -453,19 +453,18 @@ public Output<?>[] whileLoop(
453
453
synchronized SaverDef saverDef () {
454
454
if (saverDef == null ) {
455
455
// Check to see if this graph has a restore operation
456
- if (operation ("save/restore_all" ) == null ) {
456
+ if (operation (SAVER_DEF_SCOPE + "/" + SAVER_DEF_RESTORE_OP ) == null ) {
457
457
// No saver, create one by mutating the graph
458
458
saverDef = addVariableSaver (this );
459
459
} else {
460
460
// This graph already has saving/restoring operations,
461
- // regenerate SaverDef without mutating. The names mirror
462
- // the python implementation for compatibility.
463
- // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/saver.py
464
- saverDef = SaverDef .newBuilder ()
465
- .setFilenameTensorName ("save/filename" )
466
- .setSaveTensorName ("save/control_dependency" )
467
- .setRestoreOpName ("save/restore_all" )
468
- .build ();
461
+ // regenerate SaverDef without mutating.
462
+ saverDef =
463
+ SaverDef .newBuilder ()
464
+ .setFilenameTensorName (SAVER_DEF_SCOPE + "/" + SAVER_DEF_FILENAME_OP + ":0" )
465
+ .setSaveTensorName (SAVER_DEF_SCOPE + "/" + SAVER_DEF_SAVE_OP )
466
+ .setRestoreOpName (SAVER_DEF_SCOPE + "/" + SAVER_DEF_RESTORE_OP )
467
+ .build ();
469
468
}
470
469
}
471
470
return saverDef ;
@@ -570,6 +569,13 @@ public void remove() {
570
569
private int position ;
571
570
}
572
571
572
+ // These names mirror the python implementation, to reduce the risk of incompatibility.
573
+ // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/saver.py
574
+ private static final String SAVER_DEF_SCOPE = "save" ;
575
+ private static final String SAVER_DEF_FILENAME_OP = "filename" ;
576
+ private static final String SAVER_DEF_SAVE_OP = "control_dependency" ;
577
+ private static final String SAVER_DEF_RESTORE_OP = "restore_all" ;
578
+
573
579
private static TF_Graph allocate () {
574
580
return TF_NewGraph ();
575
581
}
@@ -797,7 +803,7 @@ private static Object[] whileLoop(
797
803
}
798
804
799
805
private static SaverDef addVariableSaver (Graph graph ) {
800
- Ops tf = Ops .create (graph ).withSubScope ("save" );
806
+ Ops tf = Ops .create (graph ).withSubScope (SAVER_DEF_SCOPE );
801
807
802
808
List <String > varNames = new ArrayList <>();
803
809
List <Operand <?>> varOutputs = new ArrayList <>();
@@ -812,36 +818,35 @@ private static SaverDef addVariableSaver(Graph graph) {
812
818
}
813
819
}
814
820
815
- // FIXME Need an easier way to initialize an NdArray from a list
816
- String [] tmp = new String [varNames .size ()];
817
- Constant <TString > varNamesTensor = tf .constant (StdArrays .ndCopyOf (varNames .toArray (tmp )));
818
- Operand <TString > varSlices = tf .zerosLike (varNamesTensor );
819
-
820
- Placeholder <TString > saveFilename = tf .withName ("filename" ).placeholder (TString .class );
821
- Save saveVariables = tf .train .save (
822
- saveFilename ,
823
- varNamesTensor ,
824
- varSlices ,
825
- varOutputs
826
- );
827
- Identity <TString > id = tf .withControlDependencies (Arrays .asList (saveFilename ,saveVariables ))
828
- .withName ("control_dependency" ).identity (saveFilename );
829
- Restore restoreVariables = tf .train .restore (
830
- saveFilename ,
831
- varNamesTensor ,
832
- varSlices ,
833
- varTypes
834
- );
835
- List <Op > restoreOps = new ArrayList <>(varOutputs .size ());
836
- for (int i = 0 ; i < varOutputs .size (); ++i ) {
837
- restoreOps .add (tf .assign (varOutputs .get (i ), (Operand ) restoreVariables .tensors ().get (i )));
821
+ Placeholder <TString > filename = tf .withName (SAVER_DEF_FILENAME_OP ).placeholder (TString .class );
822
+ Identity <TString > save = null ;
823
+ NoOp restore = null ;
824
+
825
+ if (varNames .isEmpty ()) {
826
+ save = tf .withName (SAVER_DEF_SAVE_OP ).identity (filename );
827
+ restore = tf .withName (SAVER_DEF_RESTORE_OP ).noOp ();
828
+ } else {
829
+ String [] tmp = new String [varNames .size ()];
830
+ Constant <TString > varNamesTensor = tf .constant (StdArrays .ndCopyOf (varNames .toArray (tmp )));
831
+ Operand <TString > varSlices = tf .zerosLike (varNamesTensor );
832
+ Save saveVars = tf .train .save (filename , varNamesTensor , varSlices , varOutputs );
833
+ List <Op > saveDeps = Arrays .asList (filename , saveVars );
834
+ Restore restoreVars = tf .train .restore (filename , varNamesTensor , varSlices , varTypes );
835
+ List <Op > restoreDeps = new ArrayList <>(varOutputs .size ());
836
+ for (int i = 0 ; i < varOutputs .size (); ++i ) {
837
+ restoreDeps .add (tf .assign (varOutputs .get (i ), (Operand ) restoreVars .tensors ().get (i )));
838
+ }
839
+ save = tf .withControlDependencies (saveDeps ).withName (SAVER_DEF_SAVE_OP ).identity (filename );
840
+ restore = tf .withControlDependencies (restoreDeps ).withName (SAVER_DEF_RESTORE_OP ).noOp ();
838
841
}
839
- NoOp restoreAll = tf .withControlDependencies (restoreOps ).withName ("restore_all" ).noOp ();
840
842
843
+ // 'Filename' must be the name of a tensor (i.e. with output index)
844
+ // 'Save' must be an operation name, even if the field name is confusing (see SaverDef doc)
845
+ // 'Restore' must be an operation name
841
846
return SaverDef .newBuilder ()
842
- .setFilenameTensorName (saveFilename . op ().name ())
843
- .setSaveTensorName (id .op ().name ())
844
- .setRestoreOpName (restoreAll .op ().name ())
847
+ .setFilenameTensorName (filename . output ().name ())
848
+ .setSaveTensorName (save .op ().name ())
849
+ .setRestoreOpName (restore .op ().name ())
845
850
.build ();
846
851
}
847
852
0 commit comments