diff --git a/README.md b/README.md index 202a6c9..26f1768 100644 --- a/README.md +++ b/README.md @@ -68,7 +68,7 @@ stage_3.enable_xformers_memory_efficient_attention() # remove line if torch.__v stage_3.enable_model_cpu_offload() -def set_scheduler(stage): +def set_scheduler(stage, scheduler_name): if scheduler_name == 'dpm++': scheduler = DPMSolverMultistepScheduler.from_config(stage.scheduler.config) scheduler.config.algorithm_type = 'dpmsolver++'