Kernels
File size: 5,094 Bytes
99af74f
e93bd1e
99af74f
e93bd1e
99af74f
 
 
 
 
 
e93bd1e
99af74f
 
e93bd1e
99af74f
 
 
 
 
 
 
 
f88998f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10848ab
f88998f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10848ab
 
 
 
f88998f
 
 
 
 
 
 
10848ab
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# MIT License
#
# Copyright (c) 2025 Tianyang Lin
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch
import triton
import triton.language as tl


def get_autotune_config():
    """Return the Triton autotuning search space used by ``mmt_kernel``.

    Every combination of output-tile height (``BLOCK_SIZE_M``), K-tile width
    (``BLOCK_SIZE_K``), tile-group size (``GROUP_SIZE_M``), software-pipeline
    stage count and warp count becomes one ``triton.Config``; the outermost
    loop varies slowest.
    """
    block_m_options = (32, 64, 128)
    block_k_options = (32, 64)
    group_m_options = (8, )
    stage_options = (3, 4, 5)
    warp_options = (4, 8)
    return [
        triton.Config(
            {
                'BLOCK_SIZE_M': block_m,
                'BLOCK_SIZE_K': block_k,
                'GROUP_SIZE_M': group_m,
            },
            num_stages=stages,
            num_warps=warps,
        )
        for block_m in block_m_options
        for block_k in block_k_options
        for group_m in group_m_options
        for stages in stage_options
        for warps in warp_options
    ]


@triton.autotune(
    configs=get_autotune_config(),
    # configs are re-benchmarked whenever the values of M or K change
    key=['M', 'K'],
    # y is written by every trial launch; have the autotuner restore its
    # contents after each config is benchmarked
    restore_value=['y'],
)
@triton.jit
def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
               BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
               GROUP_SIZE_M: tl.constexpr):
    """
    Core kernel jit function of matmul_transpose that computes y = x @ x.T
    The code is a simple adaptation from the triton `matmul` tutorial:
    https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html

    Each program computes one BLOCK_SIZE_M x BLOCK_SIZE_M tile of y. Only axis
    0 of the launch grid is read. Because x @ x.T is symmetric, only tiles on
    or above the diagonal are computed. Each strictly-upper tile also writes
    its transpose into the mirrored lower tile.

    Args:
        x: pointer to the input matrix with M rows and K columns, addressed as
            x + row * stride_xm + col * stride_xk.
        y: pointer to the M x M output, addressed as
            y + row * stride_ym + col * stride_yn. Stores are masked only
            against M, so y must really hold M x M elements.
        M, K: number of rows / columns of x.
        stride_xm, stride_xk: element strides of x.
        stride_ym, stride_yn: element strides of y.
        BLOCK_SIZE_M: tile edge of the (square) output tile.
        BLOCK_SIZE_K: width of each slice of K consumed per loop iteration.
        GROUP_SIZE_M: number of tile-rows grouped together in the program
            ordering (tutorial-style grouped ordering).
    """
    pid = tl.program_id(axis=0)
    # Map the 1-D program id onto a num_pid_m x num_pid_n grid of output tiles
    # using the tutorial's grouped ordering. Consecutive pids first walk down
    # the GROUP_SIZE_M tile-rows of a group, then advance to the next column.
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    # y is M x M, so the tile grid is square and both axes use BLOCK_SIZE_M
    num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # the last group may contain fewer than GROUP_SIZE_M tile-rows
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    # Below-diagonal tiles are filled by the transposed store of their mirror
    # tile (see the end of the kernel), so these programs have nothing to do.
    if pid_m > pid_n:
        return

    # Row indices of the two row-blocks of x feeding this tile. Taking them
    # modulo M keeps every row in bounds, so the loads need no row mask. The
    # results for wrapped rows land at output indices >= M and are discarded
    # by the store masks below.
    offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # we use a & b ptrs to denote different rows of x.
    a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
    b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)

    # accumulate in float32 regardless of the input dtype
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)

    # March along K one BLOCK_SIZE_K slice at a time. The mask zero-fills the
    # columns past K in the final (partial) slice.
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs,
                    mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
                    other=0.0)
        b = tl.load(b_ptrs,
                    mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
                    other=0.0)
        # a @ b.T contributes this K-slice to tile (pid_m, pid_n) of x @ x.T
        accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
        a_ptrs += BLOCK_SIZE_K * stride_xk
        b_ptrs += BLOCK_SIZE_K * stride_xk
    # use dtype.element_ty to accommodate different input datatypes as in cpp templates
    # https://github.com/triton-lang/triton/issues/2252
    c = accumulator.to(x.dtype.element_ty)

    # Output coordinates are not wrapped, so rows/cols >= M are masked off.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
    tl.store(c_ptrs, c, mask=c_mask)

    # transpose and copy
    # Strictly-upper tiles also write their transpose into the mirrored
    # lower-triangle tile; diagonal tiles are already complete.
    if pid_m < pid_n:
        ct_ptrs = y + stride_ym * offs_cn[:,
                                          None] + stride_yn * offs_cm[None, :]
        ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
        tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)


@torch.library.custom_op("muon::matmul_transpose_assign",
                         mutates_args=("d_out", ))
def matmul_transpose_assign(d_in: torch.Tensor, d_out: torch.Tensor) -> None:
    """Compute d_out = d_in @ d_in.T using an optimized Triton kernel."""
    d_in = d_in.contiguous()
    M, K = d_in.shape
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
        M, META['BLOCK_SIZE_M']), )
    with torch.cuda.device(d_in.device.index):
        mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
                         d_out.stride(0), d_out.stride(1))


@matmul_transpose_assign.register_fake
def _(d_in: torch.Tensor, d_out: torch.Tensor) -> None:
    """FakeTensor implementation of ``muon::matmul_transpose_assign``.

    The op returns nothing and only writes into the caller-allocated
    ``d_out`` (declared through ``mutates_args``), so there is no output
    tensor to fabricate here.
    """
    return None