Good job! But I ran into a problem with the code. When I type in the example and run it, I get a RuntimeError: "svd_cuda_gesvdjBatched" not implemented for 'Half'.
The error is accompanied by two warnings: UserWarning: cannot import name '_C' from 'sam2' and UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. I don't think those are a big deal, are they?
I hope you can give me some suggestions.
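
For reference, the error message suggests that the SVD has no half-precision kernel, so the workaround I am considering is to upcast to float32 just for the decomposition. This is only a minimal sketch: I am assuming the failing call is torch.linalg.svd on a float16 tensor (I have not confirmed where in the code it happens), and svd_fp32 is a hypothetical helper name, not something from the repository.

```python
import torch


def svd_fp32(x: torch.Tensor):
    """Run the SVD in float32, then cast the factors back to x's dtype.

    Hypothetical helper: assumes the failure comes from calling
    torch.linalg.svd directly on a float16 tensor.
    """
    orig_dtype = x.dtype
    # Upcast only for the decomposition; the rest of the model can stay in half.
    U, S, Vh = torch.linalg.svd(x.float(), full_matrices=False)
    return U.to(orig_dtype), S.to(orig_dtype), Vh.to(orig_dtype)


# Small batched example on CPU; the same call should work on a CUDA tensor.
x = torch.randn(2, 4, 3).half()
U, S, Vh = svd_fp32(x)
```

If the model runs under mixed precision, the same idea would apply at whichever point the SVD is called, but I am not sure that is the right place to patch it.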