Shared Memory Matrix Multiplication

Overview

Implement a kernel that multiplies a square matrix \(A\) by its transpose \(A^T\) and stores the result in \(\text{out}\), using shared memory to improve memory access efficiency. This version loads matrix blocks into shared memory before computation.
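Concretely, since \(B = A^T\), each output element is a dot product of two rows of \(A\):

\[
\text{out}[i,j] = \sum_{k=0}^{\text{SIZE}-1} A[i,k] \cdot B[k,j] = \sum_{k=0}^{\text{SIZE}-1} A[i,k] \cdot A[j,k]
\]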

Key concepts

In this puzzle, you’ll learn about:

  • Block-local memory management with LayoutTensor
  • Thread synchronization patterns
  • Memory access optimization using shared memory
  • Collaborative data loading with 2D indexing
  • Efficient use of LayoutTensor for matrix operations

The key insight is understanding how to use fast shared memory with LayoutTensor to reduce expensive global memory operations.
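A minimal sketch of that insight, using fragments adapted from the solution below (not standalone code):

# Naive: every operand of the dot product is fetched from slow global memory.
for k in range(size):
    acc += a[global_i, k] * b[k, global_j]

# Shared: each input element is copied from global memory once, then
# reused from fast block-local storage by every thread that needs it.
for k in range(size):
    acc += a_shared[local_i, k] * b_shared[k, local_j]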

Configuration

  • Matrix size: \(\text{SIZE} \times \text{SIZE} = 2 \times 2\)
  • Threads per block: \(\text{TPB} \times \text{TPB} = 3 \times 3\)
  • Grid dimensions: \(1 \times 1\)

Layout configuration:

  • Input A: Layout.row_major(SIZE, SIZE)
  • Input B: Layout.row_major(SIZE, SIZE) (transpose of A)
  • Output: Layout.row_major(SIZE, SIZE)
  • Shared Memory: Two TPB × TPB LayoutTensors

Memory organization:

Global Memory (LayoutTensor):          Shared Memory (LayoutTensor):
A[i,j]: Direct access                  a_shared[local_i, local_j]
B[i,j]: Transposed access              b_shared[local_i, local_j]

Code to complete

fn single_block_matmul[
    layout: Layout, size: Int
](
    out: LayoutTensor[mut=True, dtype, layout],
    a: LayoutTensor[mut=False, dtype, layout],
    b: LayoutTensor[mut=False, dtype, layout],
):
    global_i = block_dim.x * block_idx.x + thread_idx.x
    global_j = block_dim.y * block_idx.y + thread_idx.y
    local_i = thread_idx.x
    local_j = thread_idx.y
    # FILL ME IN (roughly 12 lines)


View full file: problems/p14/p14.mojo

Tips
  1. Load matrices to shared memory using global and local indices
  2. Call barrier() after loading
  3. Compute dot product using shared memory indices
  4. Check array bounds for all operations

Running the code

To test your solution, run the following command in your terminal:

magic run p14 --single-block

Your output will look like this if the puzzle isn’t solved yet:

out: HostBuffer([0.0, 0.0, 0.0, 0.0])
expected: HostBuffer([1.0, 3.0, 3.0, 13.0])
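The expected values are consistent with the host code initializing \(A\) row-major with the values \(0 \ldots 3\) (an assumption based on the buffer above; check p14.mojo):

\[
A \cdot A^T =
\begin{bmatrix} 0 & 1 \\ 2 & 3 \end{bmatrix}
\begin{bmatrix} 0 & 2 \\ 1 & 3 \end{bmatrix}
=
\begin{bmatrix} 0 \cdot 0 + 1 \cdot 1 & 0 \cdot 2 + 1 \cdot 3 \\ 2 \cdot 0 + 3 \cdot 1 & 2 \cdot 2 + 3 \cdot 3 \end{bmatrix}
=
\begin{bmatrix} 1 & 3 \\ 3 & 13 \end{bmatrix}
\]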

Solution

fn single_block_matmul[
    layout: Layout, size: Int
](
    out: LayoutTensor[mut=True, dtype, layout],
    a: LayoutTensor[mut=False, dtype, layout],
    b: LayoutTensor[mut=False, dtype, layout],
):
    global_i = block_dim.x * block_idx.x + thread_idx.x
    global_j = block_dim.y * block_idx.y + thread_idx.y
    local_i = thread_idx.x
    local_j = thread_idx.y
    a_shared = tb[dtype]().row_major[TPB, TPB]().shared().alloc()
    b_shared = tb[dtype]().row_major[TPB, TPB]().shared().alloc()

    if global_i < size and global_j < size:
        a_shared[local_i, local_j] = a[global_i, global_j]
        b_shared[local_i, local_j] = b[global_i, global_j]

    barrier()

    if global_i < size and global_j < size:
        var acc: out.element_type = 0

        @parameter
        for k in range(size):
            acc += a_shared[local_i, k] * b_shared[k, local_j]

        out[global_i, global_j] = acc


The shared memory implementation with LayoutTensor improves performance through efficient memory access patterns:

Memory Organization

Input Tensors (2×2):                Shared Memory (3×3):
Matrix A:                           a_shared:
 [a[0,0] a[0,1]]                    [s[0,0] s[0,1] s[0,2]]
 [a[1,0] a[1,1]]                    [s[1,0] s[1,1] s[1,2]]
                                    [s[2,0] s[2,1] s[2,2]]
Matrix B (transpose):               b_shared: (similar layout)
 [b[0,0] b[0,1]]                    [t[0,0] t[0,1] t[0,2]]
 [b[1,0] b[1,1]]                    [t[1,0] t[1,1] t[1,2]]
                                    [t[2,0] t[2,1] t[2,2]]

Implementation Phases:

  1. Shared Memory Setup:

    # Create 2D shared memory tensors using TensorBuilder
    a_shared = tb[dtype]().row_major[TPB, TPB]().shared().alloc()
    b_shared = tb[dtype]().row_major[TPB, TPB]().shared().alloc()
    
  2. Thread Indexing:

    # Global indices for matrix access
    global_i = block_dim.x * block_idx.x + thread_idx.x
    global_j = block_dim.y * block_idx.y + thread_idx.y
    
    # Local indices for shared memory
    local_i = thread_idx.x
    local_j = thread_idx.y
    
  3. Data Loading:

    # Load data into shared memory using LayoutTensor indexing
    if global_i < size and global_j < size:
        a_shared[local_i, local_j] = a[global_i, global_j]
        b_shared[local_i, local_j] = b[global_i, global_j]
    
  4. Computation with Shared Memory:

    # Guard ensures we only compute for valid matrix elements
    if global_i < size and global_j < size:
        # Initialize accumulator with output tensor's type
        var acc: out.element_type = 0
    
        # Compile-time unrolled loop for matrix multiplication
        @parameter
        for k in range(size):
            acc += a_shared[local_i, k] * b_shared[k, local_j]
    
        # Write result only for threads within matrix bounds
        out[global_i, global_j] = acc
    

    Key aspects:

    • Boundary Check: if global_i < size and global_j < size

      • Prevents out-of-bounds computation
      • Only valid threads perform work
      • Essential because TPB (3×3) > SIZE (2×2): threads such as (2, 2) lie outside the matrix and must neither load nor compute
    • Accumulator Type: var acc: out.element_type

      • Uses output tensor’s element type for type safety
      • Ensures consistent numeric precision
      • Initialized to zero before accumulation
    • Loop Optimization: @parameter for k in range(size)

      • Unrolls the loop at compile time
      • Enables better instruction scheduling
      • Efficient for small, known matrix sizes (see the unrolling sketch after this list)
    • Result Writing: out[global_i, global_j] = acc

      • Protected by the same guard condition
      • Only valid threads write results
      • Maintains matrix bounds safety
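
For \(\text{SIZE} = 2\), the @parameter loop is conceptually equivalent to the straight-line code below (an illustrative sketch, not actual compiler output):

# @parameter for k in range(2) expands at compile time to:
acc += a_shared[local_i, 0] * b_shared[0, local_j]
acc += a_shared[local_i, 1] * b_shared[1, local_j]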

Thread Safety and Synchronization:

  1. Guard Conditions:

    • Input Loading: if global_i < size and global_j < size
    • Computation: Same guard ensures thread safety
    • Output Writing: Protected by the same condition
    • Prevents invalid memory access and race conditions
  2. Memory Access Safety:

    • Shared memory: Accessed only within TPB bounds
    • Global memory: Protected by size checks
    • Output: Guarded writes prevent corruption

Key Language Features:

  1. LayoutTensor Benefits:

    • Direct 2D indexing simplifies code
    • Type safety through element_type
    • Efficient memory layout handling
  2. Shared Memory Allocation:

    • TensorBuilder for structured allocation
    • Row-major layout matching input tensors
    • Proper alignment for efficient access
  3. Synchronization:

    • barrier() ensures shared memory consistency
    • Proper synchronization between load and compute
    • Thread cooperation within block (see the race sketch below)
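
A minimal sketch of the race that barrier() prevents: each thread writes only its own element of the shared tiles, but the dot product reads elements written by other threads in the block.

# Thread (i, j) stores exactly one element of each tile...
a_shared[local_i, local_j] = a[global_i, global_j]
b_shared[local_i, local_j] = b[global_i, global_j]

barrier()  # without this, the reads below may observe uninitialized memory

# ...yet inside the dot-product loop each thread reads a whole row of
# a_shared and a whole column of b_shared, mostly written by other threads.
acc += a_shared[local_i, k] * b_shared[k, local_j]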

Performance Optimizations:

  1. Memory Access Efficiency:

    • Single global memory load per element (quantified after this list)
    • Multiple reuse through shared memory
    • Coalesced memory access patterns
  2. Thread Cooperation:

    • Collaborative data loading
    • Shared data reuse
    • Efficient thread synchronization
  3. Computational Benefits:

    • Reduced global memory traffic
    • Better cache utilization
    • Improved instruction throughput
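
To quantify the first point: in a naive single-block kernel each of the \(\text{SIZE}^2\) threads performs \(2 \cdot \text{SIZE}\) global loads, or \(2 \cdot \text{SIZE}^3\) in total, while this version loads each of the \(2 \cdot \text{SIZE}^2\) input elements exactly once, a \(\text{SIZE}\)-fold reduction (16 loads vs. 8 for \(\text{SIZE} = 2\)).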

This implementation significantly improves performance over the naive version by:

  • Reducing global memory accesses
  • Enabling data reuse through shared memory
  • Using efficient 2D indexing with LayoutTensor
  • Maintaining proper thread synchronization