@@ -431,11 +431,6 @@ def install_guards(self, *guards):
431431 install_guard (* [source .make_guard (guard ) for guard in guards ], skip = 1 )
432432 return {}
433433
434- def set_source_and_track_mutable (self , value , var ):
435- assert isinstance (var , VariableTracker )
436- var .source = self .source
437- return self .tx .output .side_effects .track_mutable (value , var )
438-
439434 @classmethod
440435 def _type_dispatch (cls ):
441436 return cls ._type_dispatch_impl (config .trace_numpy )
@@ -607,7 +602,6 @@ def create_2d_tma_descriptor():
607602 elif CustomizedDictVariable .is_matching_cls_hf (type (value )):
608603 self .install_guards (GuardBuilder .TYPE_MATCH )
609604 result = CustomizedDictVariable .wrap (self , value )
610- result .source = self .source
611605 return self .tx .output .side_effects .track_object_existing (value , result )
612606 elif istype (value , (dict , collections .defaultdict , collections .OrderedDict )):
613607 self .install_guards (GuardBuilder .SEQUENCE_LENGTH )
@@ -671,7 +665,7 @@ def build_key_value(i, k, v):
671665 result , user_cls = type (value ), source = self .source
672666 )
673667
674- return self .set_source_and_track_mutable (value , result )
668+ return self .tx . output . side_effects . track_mutable (value , result )
675669 elif isinstance (value , torch .nn .Module ):
676670 return self .wrap_module (value )
677671 elif ConstantVariable .is_literal (value ): # non-atomic literals
@@ -1137,7 +1131,7 @@ def build_key_value(i, k, v):
11371131 )
11381132 elif RestrictedListSubclassVariable .is_matching_cls (type (value )):
11391133 self .install_guards (GuardBuilder .SEQUENCE_LENGTH )
1140- return self .set_source_and_track_mutable (
1134+ return self .tx . output . side_effects . track_mutable (
11411135 value ,
11421136 RestrictedListSubclassVariable (
11431137 [
@@ -1148,6 +1142,7 @@ def build_key_value(i, k, v):
11481142 ],
11491143 user_cls = type (value ),
11501144 user_cls_source = AttrSource (self .source , "__class__" ),
1145+ source = self .source ,
11511146 ),
11521147 )
11531148 elif TorchScriptObjectVariable .is_matching_cls (type (value )):
@@ -1326,9 +1321,9 @@ def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]):
13261321 )
13271322 tensor_list_proxy .node .meta ["grapharg" ] = grapharg
13281323
1329- result = BaseListVariable .cls_for_instance (value )(output )
1324+ result = BaseListVariable .cls_for_instance (value )(output , source = self . source )
13301325 if istype (value , (list , collections .deque )):
1331- return self .set_source_and_track_mutable (value , result )
1326+ return self .tx . output . side_effects . track_mutable (value , result )
13321327 return result
13331328
13341329 def wrap_tuple_iterator (self , value : tuple_iterator ):
@@ -1339,11 +1334,8 @@ def wrap_tuple_iterator(self, value: tuple_iterator):
13391334 )
13401335 for i in range (tuple_iterator_len (value ))
13411336 ]
1342- result = TupleIteratorVariable (
1343- output , mutation_type = ValueMutationNew (), source = self .source
1344- )
1345-
1346- return self .set_source_and_track_mutable (value , result )
1337+ result = TupleIteratorVariable (output , source = self .source )
1338+ return self .tx .output .side_effects .track_mutable (value , result )
13471339
13481340 def wrap_range_iterator (self , value : range_iterator ):
13491341 self .install_guards (GuardBuilder .RANGE_ITERATOR_MATCH )
@@ -1512,7 +1504,7 @@ def wrap_literal(self, value):
15121504 self .install_guards (GuardBuilder .CONSTANT_MATCH )
15131505 result = ConstantVariable .create (value = value , source = self .source )
15141506 if isinstance (value , (list , set )):
1515- return self .set_source_and_track_mutable (value , result )
1507+ return self .tx . output . side_effects . track_mutable (value , result )
15161508 return result
15171509
15181510 def assert_not_wrapped_by_this_graph (self , value : torch .Tensor ):
@@ -2403,7 +2395,7 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
24032395 elif istype (example_value , tuple ):
24042396 return TupleVariable (unpacked , ** options )
24052397 elif istype (example_value , (list , immutable_list )):
2406- return ListVariable (unpacked , mutation_type = ValueMutationNew (), ** options )
2398+ return ListVariable (unpacked , ** options )
24072399 else :
24082400 assert example_value .__class__ .__module__ == "torch.return_types" or hasattr (
24092401 example_value , "_fields"
0 commit comments