@@ -379,7 +379,7 @@ def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy:
     This util applies to any factory function that takes 'size' as the first argument,
     and supports Replication and Shard placements all at zero cost.
     """
-    assert isinstance(op_schema.args_schema[0], torch.Size)
+    assert isinstance(op_schema.args_schema[0], (torch.Size, list))
     shape = op_schema.args_schema[0]
     x = torch.empty(shape, device="meta")
     stride = x.stride()
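
As a side note on the relaxed assertion: a minimal, runnable sketch (not part of the diff) showing that `torch.empty` accepts the size argument as either a `torch.Size` or a plain `list`, so the meta-device probe for shape and stride behaves identically in both cases.

```python
import torch

# Both spellings of the size argument are accepted by torch.empty, so the
# meta-device shape/stride probe works unchanged after the relaxed assert.
for size_arg in (torch.Size([4, 8]), [4, 8]):
    x = torch.empty(size_arg, device="meta")  # meta device: no real allocation
    print(tuple(x.shape), x.stride())         # prints "(4, 8) (8, 1)" both times
```
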
@@ -408,7 +408,7 @@ def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy:
     for strategy_comb in strategy_combs:
         spec_list = [DTensorSpec(mesh, specs) for specs in zip(*strategy_comb)]
         output_specs = spec_list[0]
-        output_specs.tensor_meta = TensorMeta(shape, stride, dtype)
+        output_specs.tensor_meta = TensorMeta(shape, stride, dtype)  # type: ignore[arg-type]
 
         if not is_tensor_shardable(shape, output_specs):
             continue
@@ -421,8 +421,13 @@ def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy:
             * len(strategy_combs)
         ]
 
+        # NOTE: why do we have input_specs for constructor nodes, given that they have no inputs?
+        # This is because the optimizer code expects to see input_specs for all nodes, and it
+        # uses the input_specs to determine the sharding of the output. So we have to give it
+        # something, even though it is in principle not needed.
         strategy = OpSpec(
             output_specs=output_specs,
+            input_specs=[output_specs],
             redistribute_cost=redistribute_cost,
         )
         all_strategies.append(strategy)
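
For context on the NOTE above: a minimal sketch of what the mirrored `input_specs` looks like, using hypothetical stand-in dataclasses (`SpecSketch` and `OpSpecSketch` are illustrative only, not the real `DTensorSpec`/`OpSpec`). A factory node ends up with `input_specs` that simply alias its single output spec, so a consumer that expects `input_specs` on every node can treat constructor ops uniformly.

```python
from dataclasses import dataclass, field

# Hypothetical stand-ins for DTensorSpec / OpSpec, used only to illustrate the
# shape of the strategy built above.
@dataclass
class SpecSketch:
    placements: tuple

@dataclass
class OpSpecSketch:
    output_specs: SpecSketch
    input_specs: list = field(default_factory=list)
    redistribute_cost: list = field(default_factory=list)

out = SpecSketch(placements=("Shard(0)",))
strategy = OpSpecSketch(
    output_specs=out,
    input_specs=[out],            # mirrors the output spec, as in the diff
    redistribute_cost=[[0.0]],
)

# A consumer that reads input_specs for every node can now handle the factory
# op uniformly, even though the op has no real tensor inputs.
assert strategy.input_specs[0] is strategy.output_specs
```
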