@@ -841,6 +841,76 @@ def fp8_gemm(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
     _launcher(_helion_fp8_gemm, (triton.cdiv(256, _BLOCK_SIZE_0) * triton.cdiv(256, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
     return out
 
+--- assertExpectedJournal(TestExamples.test_fused_linear_jsd)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_helpers import math as tl_math
+from torch._inductor.runtime.triton_compat import libdevice
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_fused_linear_jsd_kernel(student_logits, teacher_logits, loss, student_logits_size_0, teacher_logits_size_1, loss_stride_0, student_logits_stride_0, student_logits_stride_1, teacher_logits_stride_0, teacher_logits_stride_1, temperature, beta, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < student_logits_size_0
+    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
+    mask_1 = indices_1 < teacher_logits_size_1
+    load = tl.load(student_logits + (indices_0[:, None] * student_logits_stride_0 + indices_1[None, :] * student_logits_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    v_0 = load / temperature
+    _mask_to = tl.where(mask_0[:, None] & mask_1[None, :], v_0, tl.full([], float('-inf'), tl.float32))
+    amax = tl.cast(tl.reshape(tl.max(_mask_to, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
+    v_1 = v_0 - amax
+    v_2 = libdevice.exp(v_1)
+    _mask_to_1 = tl.where(mask_0[:, None] & mask_1[None, :], v_2, tl.full([], 0, tl.float32))
+    sum_1 = tl.cast(tl.reshape(tl.sum(_mask_to_1, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
+    v_3 = tl_math.log(sum_1)
+    v_4 = v_1 - v_3
+    load_1 = tl.load(teacher_logits + (indices_0[:, None] * teacher_logits_stride_0 + indices_1[None, :] * teacher_logits_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    v_5 = load_1 / temperature
+    _mask_to_2 = tl.where(mask_0[:, None] & mask_1[None, :], v_5, tl.full([], float('-inf'), tl.float32))
+    amax_1 = tl.cast(tl.reshape(tl.max(_mask_to_2, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
+    v_6 = v_5 - amax_1
+    v_7 = libdevice.exp(v_6)
+    _mask_to_3 = tl.where(mask_0[:, None] & mask_1[None, :], v_7, tl.full([], 0, tl.float32))
+    sum_2 = tl.cast(tl.reshape(tl.sum(_mask_to_3, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
+    v_8 = tl_math.log(sum_2)
+    v_9 = v_6 - v_8
+    student_prob_1 = tl.reshape(v_4, [_BLOCK_SIZE_0, _RDIM_SIZE_1])
+    teacher_prob_1 = tl.reshape(v_9, [_BLOCK_SIZE_0, _RDIM_SIZE_1])
+    v_10 = libdevice.exp(student_prob_1)
+    v_11 = libdevice.exp(teacher_prob_1)
+    v_12 = libdevice.exp(student_prob_1)
+    v_13 = v_11 - v_12
+    v_14 = v_13 * beta
+    v_15 = v_10 + v_14
+    v_16 = tl_math.log(v_15)
+    v_17 = teacher_prob_1 - v_16
+    v_18 = libdevice.exp(teacher_prob_1)
+    v_19 = v_18 * v_17
+    _mask_to_4 = tl.where(mask_0[:, None] & mask_1[None, :], v_19, tl.full([], 0, tl.float32))
+    teacher_div = tl.cast(tl.sum(_mask_to_4, 1), tl.float32)
+    v_20 = tl_math.log(v_15)
+    v_21 = student_prob_1 - v_20
+    v_22 = libdevice.exp(student_prob_1)
+    v_23 = v_22 * v_21
+    _mask_to_5 = tl.where(mask_0[:, None] & mask_1[None, :], v_23, tl.full([], 0, tl.float32))
+    student_div = tl.cast(tl.sum(_mask_to_5, 1), tl.float32)
+    v_24 = teacher_div - student_div
+    v_25 = v_24 * beta
+    v_26 = student_div + v_25
+    tl.store(loss + indices_0 * loss_stride_0, v_26, mask_0)
+
+def fused_linear_jsd_kernel(beta: float, ignore_index: int, temperature: float, student_logits: torch.Tensor, teacher_logits: torch.Tensor, *, _launcher=_default_launcher):
+    loss = student_logits.new_empty(student_logits.shape[0], dtype=torch.float)
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_1 = triton.next_power_of_2(teacher_logits.size(1))
+    _launcher(_helion_fused_linear_jsd_kernel, (triton.cdiv(student_logits.size(0), _BLOCK_SIZE_0),), student_logits, teacher_logits, loss, student_logits.size(0), teacher_logits.size(1), loss.stride(0), student_logits.stride(0), student_logits.stride(1), teacher_logits.stride(0), teacher_logits.stride(1), temperature, beta, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
+    return (loss / student_logits.shape[0]).sum()
+
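For orientation, the generated kernel above computes a per-row generalized Jensen-Shannon divergence between temperature-scaled log-softmax distributions, with beta interpolating between the student and teacher terms. Below is a minimal eager-mode sketch of the same math, reconstructed from the kernel's operations; the helper name jsd_reference is hypothetical and not part of the journal:

import torch

def jsd_reference(beta: float, temperature: float, student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
    # Temperature-scaled log-softmax over the class dimension.
    student_prob = torch.log_softmax(student_logits / temperature, dim=-1)
    teacher_prob = torch.log_softmax(teacher_logits / temperature, dim=-1)
    # Mixture in probability space, written as in the kernel:
    # exp(student) + beta * (exp(teacher) - exp(student))
    # i.e. (1 - beta) * student + beta * teacher.
    m = student_prob.exp() + beta * (teacher_prob.exp() - student_prob.exp())
    log_m = m.log()
    # Per-row KL(teacher || m) and KL(student || m).
    teacher_div = (teacher_prob.exp() * (teacher_prob - log_m)).sum(dim=-1)
    student_div = (student_prob.exp() * (student_prob - log_m)).sum(dim=-1)
    # Beta-weighted combination, then the same mean reduction as the wrapper.
    loss = student_div + beta * (teacher_div - student_div)
    return (loss / student_logits.shape[0]).sum()

Where the eager sketch materializes full probability tensors, the generated kernel fuses both log-softmax reductions and the divergence into one pass over each row block, so each row of logits is loaded once.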
 --- assertExpectedJournal(TestExamples.test_gather_gemv)
 from __future__ import annotations
 