GPU Kernel Optimizations

Disclaimer: These are notes for CSE 599K "LLM Serving Systems" at the University of Washington, Spring 2025, taught by Prof. Baris Kasikci and TA Kan Zhu.

GPU Architecture Recap

GPU Programming Model

| Concept | Definition | Corresponding Architecture | Communication | Limits |
|---|---|---|---|---|
| Thread | Minimal unit that executes instructions | Functional units | Local (registers) | Up to 255 registers |
| Warp | Group of threads | "SM tiles" | Register file | 32 threads |
| Thread block | Group of warps | SM | Shared memory | Up to 32 warps (1024 threads) |
| Kernel | Function on the GPU | GPU | L2 / global memory | Up to (2^31 - 1) × 65535 × 65535 blocks |

Triton Framework

CUDA

GPU Optimization Techniques

How to Write Fast Kernels

Four key optimization strategies:

  1. Coalesced Global Loading
  2. Using Shared Memory
  3. Avoiding Bank Conflicts
  4. Avoiding Branch Divergence
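
As an illustration of the first strategy, the sketch below (kernel names and sizes are illustrative, not from the lecture) contrasts a coalesced copy, where consecutive threads in a warp read consecutive addresses that the hardware can merge into a few memory transactions, with a strided copy whose 32 accesses per warp cannot be merged.

// Coalesced: consecutive threads touch consecutive addresses,
// so each warp's loads merge into a few memory transactions.
__global__ void copy_coalesced(const float *in, float *out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = in[i];
}

// Strided: threads in a warp touch addresses far apart, so the warp's
// 32 loads become many separate transactions and bandwidth is wasted.
__global__ void copy_strided(const float *in, float *out, int n, int stride) {
    int i = (blockIdx.x * blockDim.x + threadIdx.x) * stride;
    if (i < n) out[i] = in[i];
}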

Matrix Transpose Example

Bank Conflicts

Branch Divergence

Reduction Problem

Parallel Reduction Optimizations

  1. Reduction #1: Basic parallel reduction with divergent branching
  2. Reduction #2: Better access patterns to improve coalescing
  3. Reduction #3: Sequential addressing to eliminate bank conflicts
  4. Reduction #4: Load multiple elements per thread
  5. Reduction #5: Load even more elements per thread

  Trade-off: loading more elements per thread improves memory utilization, but it reduces the number of blocks, so overall GPU utilization can drop.

GEMM (General Matrix Multiplication)

Matrix Transpose Kernel Case Study

Problem Setup

Transpose V1: Row-wise Partitioning

Transpose V2: Global Memory Coalescing

Transpose V3: Tilewise Partitioning with Shared Memory

Shared Memory Allocation Methods

Static Allocation

__shared__ float f_array[10];

Dynamic Allocation

extern __shared__ int shared_mem[];
// Launch the kernel with the shared-memory size (in bytes) as the third launch parameter:
my_kernel<<<grid, block, shared_mem_size_in_bytes>>>(...);
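
Putting the two pieces together, here is a minimal sketch (kernel and sizes are illustrative): the kernel declares an unsized extern shared array, and the host reserves one float per thread at launch time.

// Stages each block's elements through dynamically sized shared memory.
__global__ void stage_copy(const float *in, float *out, int n) {
    extern __shared__ float tile[];                   // size chosen at launch time
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) tile[threadIdx.x] = in[i];
    __syncthreads();
    if (i < n) out[i] = tile[threadIdx.x];
}

void launch_stage_copy(const float *in, float *out, int n) {
    dim3 block(256);
    dim3 grid((n + block.x - 1) / block.x);
    size_t shared_bytes = block.x * sizeof(float);    // one float per thread
    stage_copy<<<grid, block, shared_bytes>>>(in, out, n);
}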

Understanding Bank Conflicts

Bank Structure: Shared memory is organized into banks (typically 32, each one 4-byte word wide). Threads in a warp can hit different banks simultaneously, but accesses from multiple threads to different addresses in the same bank are serialized (a bank conflict).

Transpose V4: Padding to Avoid Bank Conflicts
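
The sketch below (a 32×32 tile with a 32×32 thread block; illustrative, not necessarily the lecture's exact V4 kernel) stages a tile in shared memory and pads each row by one element, so the column-wise reads during the transposed write land in different banks.

#define TILE 32

// Tiled transpose of an n x n row-major matrix: out[j*n + i] = in[i*n + j].
// Launch with blockDim = (TILE, TILE) and gridDim = (ceil(n/TILE), ceil(n/TILE)).
__global__ void transpose_padded(const float *in, float *out, int n) {
    __shared__ float tile[TILE][TILE + 1];            // +1 column of padding avoids bank conflicts

    int x = blockIdx.x * TILE + threadIdx.x;          // column in the input
    int y = blockIdx.y * TILE + threadIdx.y;          // row in the input
    if (x < n && y < n)
        tile[threadIdx.y][threadIdx.x] = in[y * n + x];   // coalesced global read
    __syncthreads();

    x = blockIdx.y * TILE + threadIdx.x;              // swap block indices for the output tile
    y = blockIdx.x * TILE + threadIdx.y;
    if (x < n && y < n)
        out[y * n + x] = tile[threadIdx.x][threadIdx.y];  // coalesced global write, conflict-free read
}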


Reduction Kernel Case Study

Reduction Problem Definition

temp = identity of op  (e.g., 0 for sum, -inf for max)
for element in array:
    temp = op(temp, element)

Parallel Reduction Strategy

Instead of sequential reduction, use tree-like parallel reduction:

- Step 1: 8 elements → 4 partial results
- Step 2: 4 partial results → 2 partial results
- Step 3: 2 partial results → 1 final result
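
A minimal sketch of the tree pattern within one block (the classic shared-memory formulation, assuming blockDim.x is a power of two; not necessarily the lecture's exact code): each thread loads one element, then partial sums are combined at strides 1, 2, 4, ...

// Tree reduction: each block reduces blockDim.x elements to one partial sum.
// Launch with blockDim.x * sizeof(float) bytes of dynamic shared memory.
__global__ void reduce_tree(const float *in, float *out, int n) {
    extern __shared__ float sdata[];
    unsigned tid = threadIdx.x;
    unsigned i = blockIdx.x * blockDim.x + tid;
    sdata[tid] = (i < n) ? in[i] : 0.0f;              // coalesced load, pad the tail with the identity
    __syncthreads();

    // Interleaved addressing (the naive Reduction #1 pattern; causes divergence).
    for (unsigned s = 1; s < blockDim.x; s *= 2) {
        if (tid % (2 * s) == 0)
            sdata[tid] += sdata[tid + s];
        __syncthreads();
    }
    if (tid == 0) out[blockIdx.x] = sdata[0];         // one partial result per block
}

The per-block partial results in out are then reduced again (by another kernel launch or on the host) until a single value remains.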

Reduction Implementation Variants

Reduction #1: Interleaved Addressing

Branch Divergence in CUDA

Key Concept: Threads in a warp execute the same instruction in lockstep. When a branch splits a warp, both paths are executed one after the other with inactive threads masked off, so divergent branches serialize the warp.

Reduction #2: Sequential Access Pattern

Reduction #3: Sequential Accesses

Reduction #4: Load Two Elements

Reduction #5: Load More Elements
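
Pulling the later variants together, the sketch below (assuming a power-of-two block size; the lecture's exact unrolling may differ) has each thread first sum several elements directly from global memory with a grid-stride loop, then reduces the per-thread sums with sequential addressing, which keeps active threads packed (no divergence) and spreads them across banks (no conflicts).

// Reductions #3-#5 combined: multi-element loads + sequential addressing.
// Launch with blockDim.x * sizeof(float) bytes of shared memory and fewer
// blocks than ceil(n / blockDim.x), so each thread covers several elements.
__global__ void reduce_multi(const float *in, float *out, int n) {
    extern __shared__ float sdata[];
    unsigned tid = threadIdx.x;

    float sum = 0.0f;                                  // per-thread accumulation in registers
    for (unsigned i = blockIdx.x * blockDim.x + tid; i < n; i += gridDim.x * blockDim.x)
        sum += in[i];
    sdata[tid] = sum;
    __syncthreads();

    // Sequential addressing: active threads are 0..s-1, so warps retire whole.
    for (unsigned s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s)
            sdata[tid] += sdata[tid + s];
        __syncthreads();
    }
    if (tid == 0) out[blockIdx.x] = sdata[0];
}

This is the trade-off noted earlier: fewer blocks mean more loads per thread and better use of memory bandwidth, but too few blocks leave SMs idle.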


GEMM (General Matrix Multiply) Optimization

GEMM Memory Access Pattern

For matrices of size $M \times K$ and $K \times N$:

- Per output element: load one row + one column = 2K elements
- Total memory loads: 2MNK
- Unique loads: only MK + NK
- Problem: massive redundancy in memory access
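
The redundancy factor follows directly from these counts:

$$\frac{\text{total loads}}{\text{unique elements}} = \frac{2MNK}{MK + NK} = \frac{2MN}{M + N},$$

which equals N for square matrices, i.e., every element of A and B is read from memory about N times in the naive scheme.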

GEMM Tiling Strategy

Load by Tiles:

Memory Load Reduction: $$L = \frac{Tile_M + Tile_N}{Tile_M \cdot Tile_N} \cdot MNK$$

Key Benefit: relative to the naive 2MNK loads, global-memory/L2 traffic shrinks by a factor of $\frac{2 \cdot Tile_M \cdot Tile_N}{Tile_M + Tile_N}$, roughly the tile dimension for square tiles.
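
For a concrete (illustrative) size, take $M = N = K = 4096$ with $64 \times 64$ output tiles:

$$2MNK = 2 \cdot 4096^3 \approx 1.4 \times 10^{11}
\quad\text{vs.}\quad
\frac{64 + 64}{64 \cdot 64} \cdot MNK = \frac{MNK}{32} \approx 2.1 \times 10^{9},$$

a reduction of $\frac{2 \cdot 64 \cdot 64}{64 + 64} = 64\times$ in element loads.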

Tensor Cores

Definition: Special hardware units that perform small GEMM operations
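
In CUDA C++ these units are exposed through the warp-level WMMA API (requires sm_70 or newer); below is a minimal sketch in which one warp multiplies a single 16×16×16 FP16 tile pair with FP32 accumulation. This is illustrative, not the lecture's code.

#include <mma.h>
#include <cuda_fp16.h>
using namespace nvcuda;

// One warp computes C (16x16, FP32) += A (16x16, FP16) * B (16x16, FP16).
// All tiles are row-major with leading dimension 16; launch as <<<1, 32>>>.
__global__ void wmma_16x16x16(const half *A, const half *B, float *C) {
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;

    wmma::fill_fragment(c_frag, 0.0f);                      // zero the accumulator
    wmma::load_matrix_sync(a_frag, A, 16);                  // cooperative load by the whole warp
    wmma::load_matrix_sync(b_frag, B, 16);
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);         // tensor-core matrix multiply-accumulate
    wmma::store_matrix_sync(C, c_frag, 16, wmma::mem_row_major);
}

A full GEMM tiles the K dimension and loops this multiply-accumulate over K/16 fragment pairs per output tile.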

GEMM Hierarchy


High-Performance Kernel Libraries

Essential Libraries

cuBLAS

CUTLASS

Raft

FlashInfer

CUB


Python Integration

Pybind11 for CUDA Kernels

Basic Pattern:

#include <pybind11/pybind11.h>
#include <torch/torch.h>
#include <torch/extension.h>
#include <cuda_runtime.h>

// Element-wise addition: one thread per output element.
__global__ void add_kernel(int *a, int *b, int *c, size_t num) {
    int block_start = blockIdx.x * blockDim.x;
    int thread_id = threadIdx.x;
    int index = block_start + thread_id;
    if (index < num) {  // guard threads of the last block that fall past the end
        c[index] = a[index] + b[index];
    }
}

torch::Tensor add(torch::Tensor a, torch::Tensor b) {
    auto num = a.size(0);
    auto c = torch::empty_like(a);  // output on the same device with the same dtype

    int threads_per_block = 256;
    int blocks_per_grid = (num + threads_per_block - 1) / threads_per_block;  // ceiling division

    // data_ptr<int>() assumes int32 CUDA tensors; it exposes the raw device pointer.
    add_kernel<<<blocks_per_grid, threads_per_block>>>(
        a.data_ptr<int>(), b.data_ptr<int>(), c.data_ptr<int>(), num);
    cudaDeviceSynchronize();  // wait for the kernel to finish before returning to Python
    return c;
}

PYBIND11_MODULE(my_addition, m) {
    m.def("add", &add, "Add two tensors");
}
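
On the Python side, one common way (an assumption about the build setup, not stated in the notes) to compile this file is with the helpers in torch.utils.cpp_extension, either CUDAExtension in a setup.py or the JIT load() function, after which the kernel is callable as my_addition.add(a, b) on int32 CUDA tensors.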

Key Optimization Principles Summary

  1. Memory Coalescing: Ensure contiguous memory access within warps
  2. Shared Memory: Use as high-speed cache for frequently accessed data
  3. Bank Conflict Avoidance: Pad shared memory arrays when necessary
  4. Branch Divergence Minimization: Structure algorithms to keep warps synchronized
  5. Occupancy vs Efficiency: Balance thread utilization with per-thread work
  6. Hierarchical Tiling: Optimize for different levels of memory hierarchy