diff --git a/tests/adapters.py b/tests/adapters.py
index a955e4226..89020b8f5 100644
--- a/tests/adapters.py
+++ b/tests/adapters.py
@@ -89,7 +89,7 @@ def run_swiglu(
 def run_scaled_dot_product_attention(
     Q: Float[Tensor, " ... queries d_k"],
     K: Float[Tensor, " ... keys d_k"],
-    V: Float[Tensor, " ... values d_v"],
+    V: Float[Tensor, " ... keys d_v"],
     mask: Bool[Tensor, " ... queries keys"] | None = None,
 ) -> Float[Tensor, " ... queries d_v"]:
     """
@@ -99,7 +99,7 @@ def run_scaled_dot_product_attention(
     Args:
         Q (Float[Tensor, " ... queries d_k"]): Query tensor
         K (Float[Tensor, " ... keys d_k"]): Key tensor
-        V (Float[Tensor, " ... values d_v"]): Values tensor
+        V (Float[Tensor, " ... keys d_v"]): Values tensor
         mask (Bool[Tensor, " ... queries keys"] | None): Mask tensor
     Returns:
         Float[Tensor, " ... queries d_v"]: Output of SDPA