Triton Puzzles(Demo1-4) Triton Puzzles之前做tilelang puzzles的时候发现readme里提到是仿照triton puzzles的但当时感觉triton没有学的必要就没做最近发现triton的设计思想和tilelang差异很大感觉可以开拓一下视野就找到这个https://github.com/SiriusNEO/Triton-Puzzles-Lite项目看看这是改进过的轻量版不是原版triton puzzles题目内容没变只是减少了依赖原本的可视化和jupyter notebook都去掉了就在.py文件运行并且附上了作者写的答案可以对比学习。环境需要别的可能也行但是作者建议这个这个是肯定可以跑通的高版本可能报错。pipinstalltorch2.5.0# Check triton version: triton3.1.0安装时如果能访问外网实测最快的是直接用pytorch库其他源都可能会把torch下载限速因为下载的人太多了还大可能把带宽占满了然后注意这里的cu124根据显卡的驱动版本来安装可以先检查cuda驱动版本。然后选一个低于驱动版本的torch whl这种库都是可以向下兼容但不能向上兼容。python3-mpipinstalltorch2.5.0 --index-url https://download.pytorch.org/whl/cu124 --no-cache-dir--isolated推荐使用2.0.0的numpy结果正确性验证时会用numpycheck脚本用了低版本接口版本高了会出错。python3-mpipinstallnumpy1.26.4--isolated运行时设置环境变量1表示用cpu模式py解释器运行0则是gpu模式。gpu模式由于显卡版本不同可能出现各种bug推荐先cpu模式跑通这也是原始triton puzzles的推荐运行方式。当前仓库的答案gpu模式下case 11会运行出错。TRITON_INTERPRET1python3 puzzles.py-a最后的参数部分-a#运行全部puzzles-px#运行第x个puzzle-i#运行四个demo-h#显示帮助文档clone下来可以先跑以下指令验证所有答案cpu模式下是不是都能跑通能的话说明基础环境配置没问题。TRITON_INTERPRET1python3 puzzles_ans.py-aTriton简介Triton 是由 OpenAI 开源的一种专为深度学习加速设计的编程语言和编译器。如果你写过 CUDA你可能会觉得它太底层、开发周期太长如果你只用 PyTorch你可能会发现很多自定义的算子比如各种新型的 Attention 或量化算子无法获得极致的性能。Triton 的诞生正是为了在“开发效率”与“极致性能”之间取得完美的平衡。1. Triton 解决的核心痛点在传统的 GPU 算子开发中通常面临两极分化高端玩家写 CUDA C 可以手动控制线程块、共享内存Shared Memory和寄存器性能毁天灭地但开发极其痛苦且代码很难跨硬件比如从 NVIDIA 转到 AMD复用。普通玩家写 PyTorch/TensorFlow 拼凑现有的 API如 torch.relu torch.matmul开发极快但会在显存中产生大量中间变量造成频繁的显存读写Memory Bound浪费算力。Triton 的核心思想是让没有 CUDA 经验的深度学习研究员也能用类似 Python 的语法写出性能媲美甚至超越专家级 CUDA 的硬件加速算子。2. Triton 的核心设计理念基于块Block-based的编程这也是 Triton 与 CUDA 最本质的区别CUDA 是“基于线程Thread-based”的 你需要精确计算每个 Thread 的 ID去算它该读哪一个具体的显存地址还要手动处理线程之间的同步__syncthreads()和数据共享。Triton 是“基于块Block-based”的 它把张量块Block作为一等公民First-class citizen。你不需要操心单个线程而是直接对一个分块进行加载tl.load、计算tl.dot和存储tl.store。并且triton除了是基于数据块的编程还是声明式编程而不是CUDA的过程式编程也就是你只用写要对这个数据块做什么而不需要写怎么做编译器会把做什么转化成怎么做的机器码。3. 编译器在幕后做了什么既然写起来像 Python 一样简单那极致的性能是怎么来的这全靠 Triton 编译器。它会把你的 Python 风格代码编译成高效的机器码通过 LLVM IR 到 PTX/AMDGCN自动帮你做好以下最头疼的硬件优化自动内存合并Memory Coalescing 自动优化全局显存Global Memory的访问模式确保带宽跑满。自动管理共享内存Shared Memory Allocation 你不需要像写 CUDA 那样手动声明shared数组编译器会自己决定什么时候把数据缓存在片上高速存储里。指令流水线与排程Instruction Scheduling 自动隐藏访存延迟让计算单元Tensor Cores和访存单元能够高效并发。注意这里和tilelang的设计思想不同并不会先映射到CUDA代码再编译。而是自定义了TTIRTriton IR生成TTIR后下一步就会映射到PTX、SaaS代码了不会经过CUDA也就是triton可以被视为一个独立的语言有自己的编译路径而不是CUDA语法糖。4. 谁在用 Triton如今 Triton 已经成为大模型时代基础设施的绝对主力PyTorch 2.0 的核心 PyTorch 2.0 引入的重磅编译功能 TorchInductor其后端默认就是将 PyTorch 代码自动生成为 Triton 内核这也是其实现图编译加速的秘密武器。FlashAttention 著名的闪电注意力机制其后续的很多高效变体和工程实现如 FlashAttention-3都大量采用了 Triton 进行快速迭代。大模型推理加速 比如 vLLM、DeepSpeed 以及各类轻量级量化插件里面普遍包含大量用 Triton 编写的定制化算子如上面我们聊到的量化 GEMM。如果你想深入 AI 芯片底层硬件加速或者想为自己的大模型设计专属的奇门遁甲算子Triton 是目前投产比ROI最高、最值得学习的技术。Demosdemo 1数据搬运是GPU编程中最核心的概念第一个示例主要熟悉tl.load搬运数据tl.load(ptr, mask)参数是两个张量ptr是一个指针数组表示数据搬运源地址数组内每个指针对应一个要搬运的元素。mask是一个掩码数组数据类型是bool用0/1表示ptr数组中传入的每个指针是否搬运。需要额外引入mask的原因是triton里的所有张量数据块的大小都是二的幂次如果我们想灵活搬运一个大小不对齐的张量时比如大小5可以传入一个刚好大于这个张量大小的指针数组长度对齐2的幂次然后用mask来约束搬运范围比如mask就是[1,1,1,1,1,0,0,0]表示前五个位置利用指针地址搬运后三个位置不进行操作。需要注意的是这里传入的x_ptr已经不是torch tensor了而是底层数据的首地址类似c的数组首地址指针这也是命名上带一个ptr的原因因此我们传入指针ptr数组和mask需要人为避免越界如果x_ptr对应的tensor只有八个元素那么就不能访问大于8的位置否则会运行错误或者读到垃圾值。编译器不会阻止你编译时的思路是类C的允许你直接用指针寻址。如果指针数组大小超过tensor了但是mask限制了读取范围不会出问题因为mask为0的位置不会真的去读内存而是直接返回一个值表示不操作可以在tl.load(ptr, mask,0)操作时传入第三个参数表示mask为0的位置填充什么值如果不传入第三个参数默认填充0定义讲完了来看这个算子的具体事项。range tl.arange(0, 8)类似torch.arrange生成一个公差为1的等差数列左闭右开。x tl.load(x_ptr range, range 5, 0)这一行有很多看点。x_ptr range这里的x_ptr本身是一个指针也就是一个标量但是range是刚才生成的数据块两者相加这里triton规定遵循torch/numpy的广播规则把标量广播到和张量一样的shape再执行相加。也就是此时形成了一个[x_ptr,x_ptr1,x_ptr2,...,x_ptr7]的指针数组接下来会去这个数组内的位置搬运数据。range 5类似5是一个标量会广播到和range一样大然后操作会返回一个bool数组用这个方式就构造了一个[1 1 1 1 1 0 0 0]的maskx tl.load(x_ptr range, range 5, 0)最后load返回的是一个triton数据块需要把它复制给一个变量保存下来。demo1[(1, 1, 1)](torch.ones(4, 3))最后是triton内核的启动方法triton设计时DSL还没这么多很多设计师对齐CUDA比如这里(1, 1, 1)就是CUDA启动时传入的launch参数dim3表示grid shape或者说三个维度的block个数。传递给函数的直接参数则在后面圆括号内这里传入一个二维张量(torch.ones(4, 3))。可能会好奇这里传入的是二维张量但kernel内看起来是把他当成一维数组用的这也是类C设计带来的CUDA编程时多维数组不管几维都是当成一维数组使用用的时候再多次寻址实现多维数组的效果triton继承了这一点这个张量4*312个元素在triton kernel内会看成一个长度12的连续内存。r ## Introduction To begin with, we will only use tl.load and tl.store in order to build simple programs. ### Demo 1 Heres an example of load. It takes an arange over the memory. By default the indexing of torch tensors with column, rows, depths or right-to-left. It also takes in a mask as the second argument. Mask is critically important because all shapes in Triton need to be powers of two. Expected Results: [0 1 2 3 4 5 6 7] [1. 1. 1. 1. 1. 0. 0. 0.] Explanation: tl.load(ptr, mask) tl.load use mask: [0 1 2 3 4 5 6 7] 5 [1 1 1 1 1 0 0 0] triton.jitdefdemo1(x_ptr):rangetl.arange(0,8)# print works in the interpreterprint(range)xtl.load(x_ptrrange,range5,0)print(x)defrun_demo1():print(Demo1 Output: )demo1[(1,1,1)](torch.ones(4,3))print_end_line()demo 2仍然是load只是这次需要load一个复杂一点的二维区域i 4 and j 3那么用一个range mask就有点难做到了可以用两个。首先构造两个等差数列一个对应行一个对应列。然后给他们升维类似torch.unsqueeze弄完之后两个mask的shape分别是(8,1)(1,4)i_rangetl.arange(0,8)[:,None]j_rangetl.arange(0,4)[None,:]range i_range * 4 j_range让这两个mask做加法遵循torch/numpy广播规则会都先变成(8,4)再执行加法。并且加之前先把行张量乘上每一行的元素个数这样最后得到的结果每个位置的值都等于把这个张量展开到一维后这个位置的编号可以用来构造mask数组了(i_range 4) (j_range 3)构造mask时可以把两个条件取and这里重载了的规则不是py里的按位与而是表示and。这样我们就限制了只拷贝i 4 and j 3的区域 ### Demo 2: You can also use this trick to read in a 2d array. Expected Results: [[ 0 1 2 3] [ 4 5 6 7] [ 8 9 10 11] [12 13 14 15] [16 17 18 19] [20 21 22 23] [24 25 26 27] [28 29 30 31]] [[1. 1. 1. 0.] [1. 1. 1. 0.] [1. 1. 1. 0.] [1. 1. 1. 0.] [0. 0. 0. 0.] [0. 0. 0. 0.] [0. 0. 0. 0.] [0. 0. 0. 0.]] Explanation: tl.load use mask: i 4 and j 3. triton.jitdefdemo2(x_ptr):i_rangetl.arange(0,8)[:,None]j_rangetl.arange(0,4)[None,:]rangei_range*4j_range# print works in the interpreterprint(range)xtl.load(x_ptrrange,(i_range4)(j_range3),0)print(x)defrun_demo2():print(Demo2 Output: )demo2[(1,1,1)](torch.ones(4,4))print_end_line()demo 3这节主要是学习tl.store写入操作和读取tl.load一起构成了完整的数据搬运。tl.store(ptr, value, mask)参数和tl.load类似也是传入一个指针数组一个mask只不过这是个无返回值的函数所以ptr就是目的地址源则是value。ptr类似前面load的规则类C的指针数组手动寻址。但value是类似py张量可以传入一个标量进行广播也可以传入一个前面load进来的张量不能传入和ptr类似的指针数组也就是源不是给传指针寻址而是直接给出值。一般的范式是读取到一个张量做想做的操作然后再写入也就是读取写入之间一定有一个张量来倒手。xtl.load(x_ptr,mask)tl.store(y_ptr,x,mask)来看具体实现z tl.store(z_ptr range, 10, range 5)这里用z接受了返回值其实是一个陷阱tl.store无返回值所以尝试print(z)会报错。想看结果数据已经被写入z_ptr为首地址的张量了在kernel内只有首地址指针没有z_ptr对应的张量对象看不了必须从kernel里返回后host侧才能看。 ### Demo 3 The tl.store function is quite similar. It allows you to write to a tensor. Expected Results: tensor([[10., 10., 10.], [10., 10., 1.], [ 1., 1., 1.], [ 1., 1., 1.]]) Explanation: tl.store(ptr, value, mask) here range 5 corresponds to the 2D-mask [[1. 1. 1.] [1. 1. 0.] [0. 0. 0.] [0. 0. 0.]] triton.jitdefdemo3(z_ptr):rangetl.arange(0,8)ztl.store(z_ptrrange,10,range5)defrun_demo3():print(Demo3 Output: )ztorch.ones(4,3)demo3[(1,1,1)](z)print(z)print_end_line()demo 4前三个都是单线程的但作为GPU编程当然可以根据数据块编号不同做不同的操作这节来看如何利用tl.program_id确定所在块号然后执行不同操作。tl.program_id(0)这里的012分别是取出这个数据块的三个维度编号三个维度是我们启动内核时传入的比如这里就是demo4[(3, 1, 1)](x)表示0维度长度3另外1,2维度长度1也就是有3 * 1 * 1 3个block。x torch.ones(2, 4, 4)传入的张量展平后有32个元素想要搬运前20个。均分给三个block实现考虑到每次搬运操作的长度都是二的幂次最少的搬运方式是每个block搬8个元素前两个block都全搬最后一个block只用搬前四个设一个mask实现这一点。kernel内range tl.arange(0, 8) pid * 8实现了每个block搬运的位置不同也就是根据block id进行偏移。每个都搬长度为8的区间所以生成一个长度8的等差数列然后累加上块偏移就是这个块负责的地址范围range 20为了只搬前20个增加一个mask限制这个限制只会让最后一个block的mask是前四个1后四个0对前两个block无影响。 ### Demo 4 You can only load in relatively small blocks at a time in Triton. To work with larger tensors you need to use a program id axis to run multiple blocks in parallel. Here is an example with one program axis with 3 blocks. Expected Results: Print for each [0] [1. 1. 1. 1. 1. 1. 1. 1.] Print for each [1] [1. 1. 1. 1. 1. 1. 1. 1.] Print for each [2] [1. 1. 1. 1. 0. 0. 0. 0.] Explanation: This program launch 3 blocks in parallel. For each block (pid0, 1, 2), it loads 8 elements. Note that similar to demo3, multi-dimensional tensors are flattened when we use pointer (i.e. continuous in memory). triton.jitdefdemo4(x_ptr):pidtl.program_id(0)rangetl.arange(0,8)pid*8xtl.load(x_ptrrange,range20)print(Print for each,pid,x)defrun_demo4():print(Demo4 Output: )xtorch.ones(2,4,4)demo4[(3,1,1)](x)print_end_line()