File tree Expand file tree Collapse file tree 3 files changed +11
-8
lines changed Expand file tree Collapse file tree 3 files changed +11
-8
lines changed Original file line number Diff line number Diff line change @@ -116,6 +116,9 @@ def plan_spec(
116
116
Greedily place the spec in the first memory that can fit it.
117
117
"""
118
118
for spec .mem_id in range (1 , self .get_num_memories ()):
119
+ if placement_constraints .is_mem_id_in_blocklist (spec , spec .mem_id ):
120
+ # Skip placement for blocked memory id.
121
+ continue
119
122
prev_offset , smallest_gap = 0 , float ("inf" )
120
123
for allocated_spec in state .allocated_buffers [spec .mem_id ]:
121
124
if not Verifier .lifetime_overlap (spec , allocated_spec ):
@@ -141,11 +144,11 @@ def plan_spec(
141
144
)
142
145
if spec .mem_offset is None :
143
146
spec .mem_offset = prev_offset
144
- if not self . is_valid_placement ( spec , placement_constraints ):
145
- spec . mem_offset = None
146
- continue
147
- else :
148
- spec . mem_offset = prev_offset
147
+
148
+ if not self . is_valid_placement ( spec , placement_constraints ):
149
+ # Skip placement for invalid memory id.
150
+ spec . mem_offset = None
151
+ continue
149
152
150
153
state .place_spec (spec )
151
154
# A data structure used for maintaining the tensor order
Original file line number Diff line number Diff line change @@ -204,7 +204,7 @@ def _place_memory_id_pinned_specs(
204
204
for spec , c in spec_with_abs_constraint .items ()
205
205
if c is not None and c .pinned_memory_id == mem_id and c .offset is None
206
206
}
207
- logging .error (f"Placing specs { mem_id_pinned_specs } for { mem_id = } " )
207
+ logging .debug (f"Placing specs { mem_id_pinned_specs } for { mem_id = } " )
208
208
209
209
with self .block_memories_except (mem_id ):
210
210
self .plan (
@@ -220,7 +220,7 @@ def _place_memory_id_pinned_specs(
220
220
if constraint is None :
221
221
continue
222
222
223
- logging .error (f"Placing spec { spec } with { constraint } " )
223
+ logging .debug (f"Placing spec { spec } with { constraint } " )
224
224
225
225
if not state .is_placed (spec ):
226
226
raise MemoryError (
Original file line number Diff line number Diff line change @@ -1044,7 +1044,7 @@ class DummyMemIdBlockConstraintGen(PassBase):
1044
1044
mul: blocks 1, 3
1045
1045
"""
1046
1046
1047
- def __init__ (self , memory_constraints : MemoryConfig ):
1047
+ def __init__ (self , memory_constraints : MemConstraints ):
1048
1048
self .memory_constraints = memory_constraints
1049
1049
1050
1050
def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
You can’t perform that action at this time.
0 commit comments