The `train_step` function uses a potentially inaccurate formula to calculate `zt`:

```python
zt = αb*u_y + (1-αb).sqrt()*eps
```

I believe it should be replaced by:

```python
zt = αb.sqrt()*u_y + (1-αb).sqrt()*eps
```
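A quick sanity check of why the square root matters. This is only a sketch: it assumes `αb` is the cumulative noise-schedule product from the usual DDPM-style forward process, and that `u_y` and `eps` are independent with unit variance. Neither assumption is confirmed by the repository code, and `zt_variance` is a hypothetical helper written for this illustration.

```python
import math


def zt_variance(alpha_bar: float, use_sqrt: bool) -> float:
    """Variance of zt = a*u_y + b*eps for independent, unit-variance u_y and eps.

    `alpha_bar` stands in for one scalar entry of `αb` (assumed to lie in (0, 1)).
    """
    # Signal coefficient: sqrt(αb) in the proposed fix, plain αb in the current code.
    signal_coeff = math.sqrt(alpha_bar) if use_sqrt else alpha_bar
    noise_coeff = math.sqrt(1.0 - alpha_bar)
    # Var(a*u + b*eps) = a^2 + b^2 when u and eps are independent with unit variance.
    return signal_coeff ** 2 + noise_coeff ** 2


for alpha_bar in (0.9, 0.5, 0.1):
    print(
        f"αb={alpha_bar}: "
        f"with sqrt={zt_variance(alpha_bar, True):.4f}, "
        f"without sqrt={zt_variance(alpha_bar, False):.4f}"
    )
```

Under those assumptions, the variant with the square root keeps the variance of `zt` at 1 for every value of `αb`, while the current formula gives `αb**2 + (1 - αb)`, which is below 1 for any `αb` strictly between 0 and 1.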