# Shared Memory Matrix Multiplication

## Overview
Implement a kernel that multiplies a square matrix \(A\) by its transpose \(A^T\) and stores the result in \(\text{out}\), using shared memory to improve memory access efficiency. This version loads matrix blocks into shared memory before computing.
## Key concepts

In this puzzle, you’ll learn about:
- Block-local memory management with LayoutTensor
- Thread synchronization patterns
- Memory access optimization using shared memory
- Collaborative data loading with 2D indexing
- Efficient use of LayoutTensor for matrix operations
The key insight is understanding how to use fast shared memory with LayoutTensor to reduce expensive global memory operations.
## Configuration

- Matrix size: \(\text{SIZE} \times \text{SIZE} = 2 \times 2\)
- Threads per block: \(\text{TPB} \times \text{TPB} = 3 \times 3\)
- Grid dimensions: \(1 \times 1\)
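In the puzzle harness (see `problems/p14/p14.mojo`), these numbers translate into a launch configuration along the lines sketched below. This is a hedged sketch: it assumes the `DeviceContext.enqueue_function` launch API and elides all buffer and tensor setup; `ctx`, `out_tensor`, `a_tensor`, and `b_tensor` are assumed names, not taken from this excerpt.

```mojo
alias SIZE = 2  # matrices are SIZE × SIZE
alias TPB = 3   # each block is TPB × TPB threads; TPB > SIZE, hence bounds checks

# Launch sketch (setup elided; names assumed):
# ctx.enqueue_function[single_block_matmul[layout, SIZE]](
#     out_tensor, a_tensor, b_tensor,
#     grid_dim=(1, 1),        # a single block covers the whole matrix
#     block_dim=(TPB, TPB),
# )
```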
Layout configuration:

- Input A: `Layout.row_major(SIZE, SIZE)`
- Input B: `Layout.row_major(SIZE, SIZE)` (transpose of A)
- Output: `Layout.row_major(SIZE, SIZE)`
- Shared memory: two `TPB × TPB` LayoutTensors
Memory organization:

```txt
Global Memory (LayoutTensor):      Shared Memory (LayoutTensor):
A[i,j]: Direct access              a_shared[local_i, local_j]
B[i,j]: Transposed access          b_shared[local_i, local_j]
```
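With `Layout.row_major(SIZE, SIZE)`, the 2D index `a[i, j]` resolves to the flat offset \(i \cdot \text{SIZE} + j\). A minimal sketch of that arithmetic; the `row_major_offset` helper is hypothetical, shown only to spell out what the layout does for you:

```mojo
# Hypothetical helper illustrating row-major index arithmetic:
# for Layout.row_major(rows, cols), element (i, j) sits at offset i * cols + j.
fn row_major_offset(i: Int, j: Int, cols: Int) -> Int:
    return i * cols + j
```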
## Code to complete

```mojo
fn single_block_matmul[
    layout: Layout, size: Int
](
    out: LayoutTensor[mut=True, dtype, layout],
    a: LayoutTensor[mut=False, dtype, layout],
    b: LayoutTensor[mut=False, dtype, layout],
):
    global_i = block_dim.x * block_idx.x + thread_idx.x
    global_j = block_dim.y * block_idx.y + thread_idx.y
    local_i = thread_idx.x
    local_j = thread_idx.y
    # FILL ME IN (roughly 12 lines)
```
View full file: `problems/p14/p14.mojo`
## Tips

- Load matrices into shared memory using global and local indices
- Call `barrier()` after loading
- Compute the dot product using shared memory indices
- Check array bounds for all operations
## Running the code

To test your solution, run the following command in your terminal:

```bash
magic run p14 --single-block
```

Your output will look like this if the puzzle isn’t solved yet:

```txt
out: HostBuffer([0.0, 0.0, 0.0, 0.0])
expected: HostBuffer([1.0, 3.0, 3.0, 13.0])
```
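The expected values correspond to \(A A^T\) with \(A\) holding \(0, 1, 2, 3\) in row-major order (an initialization inferred from the expected output, not shown in this excerpt):

\[
A = \begin{bmatrix} 0 & 1 \\ 2 & 3 \end{bmatrix}, \quad
A^T = \begin{bmatrix} 0 & 2 \\ 1 & 3 \end{bmatrix}, \quad
A A^T = \begin{bmatrix} 0 \cdot 0 + 1 \cdot 1 & 0 \cdot 2 + 1 \cdot 3 \\ 2 \cdot 0 + 3 \cdot 1 & 2 \cdot 2 + 3 \cdot 3 \end{bmatrix}
= \begin{bmatrix} 1 & 3 \\ 3 & 13 \end{bmatrix}
\]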
## Solution

```mojo
fn single_block_matmul[
    layout: Layout, size: Int
](
    out: LayoutTensor[mut=True, dtype, layout],
    a: LayoutTensor[mut=False, dtype, layout],
    b: LayoutTensor[mut=False, dtype, layout],
):
    global_i = block_dim.x * block_idx.x + thread_idx.x
    global_j = block_dim.y * block_idx.y + thread_idx.y
    local_i = thread_idx.x
    local_j = thread_idx.y

    # Allocate two TPB × TPB tiles in block-local shared memory
    a_shared = tb[dtype]().row_major[TPB, TPB]().shared().alloc()
    b_shared = tb[dtype]().row_major[TPB, TPB]().shared().alloc()

    # Each in-bounds thread loads one element of each input into the tiles
    if global_i < size and global_j < size:
        a_shared[local_i, local_j] = a[global_i, global_j]
        b_shared[local_i, local_j] = b[global_i, global_j]

    # All threads must finish loading before any thread reads the tiles
    barrier()

    if global_i < size and global_j < size:
        var acc: out.element_type = 0

        @parameter
        for k in range(size):
            acc += a_shared[local_i, k] * b_shared[k, local_j]

        out[global_i, global_j] = acc
```
The shared memory implementation with LayoutTensor improves performance through efficient memory access patterns:
### Memory Organization

```txt
Input Tensors (2×2):     Shared Memory (3×3):
Matrix A:                a_shared:
[a[0,0] a[0,1]]          [s[0,0] s[0,1] s[0,2]]
[a[1,0] a[1,1]]          [s[1,0] s[1,1] s[1,2]]
                         [s[2,0] s[2,1] s[2,2]]

Matrix B (transpose):    b_shared: (similar layout)
[b[0,0] b[0,1]]          [t[0,0] t[0,1] t[0,2]]
[b[1,0] b[1,1]]          [t[1,0] t[1,1] t[1,2]]
                         [t[2,0] t[2,1] t[2,2]]
```
Implementation Phases:

1. Shared Memory Setup:

   ```mojo
   # Create 2D shared memory tensors using TensorBuilder
   a_shared = tb[dtype]().row_major[TPB, TPB]().shared().alloc()
   b_shared = tb[dtype]().row_major[TPB, TPB]().shared().alloc()
   ```

2. Thread Indexing:

   ```mojo
   # Global indices for matrix access
   global_i = block_dim.x * block_idx.x + thread_idx.x
   global_j = block_dim.y * block_idx.y + thread_idx.y

   # Local indices for shared memory
   local_i = thread_idx.x
   local_j = thread_idx.y
   ```

3. Data Loading and Synchronization:

   ```mojo
   # Load data into shared memory using LayoutTensor indexing
   if global_i < size and global_j < size:
       a_shared[local_i, local_j] = a[global_i, global_j]
       b_shared[local_i, local_j] = b[global_i, global_j]

   # Block until every thread has finished loading
   barrier()
   ```

4. Computation with Shared Memory:

   ```mojo
   # Guard ensures we only compute for valid matrix elements
   if global_i < size and global_j < size:
       # Initialize accumulator with output tensor's type
       var acc: out.element_type = 0

       # Compile-time unrolled loop for matrix multiplication
       @parameter
       for k in range(size):
           acc += a_shared[local_i, k] * b_shared[k, local_j]

       # Write result only for threads within matrix bounds
       out[global_i, global_j] = acc
   ```
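Because `size` is a compile-time parameter, the `@parameter` loop disappears during compilation. For `size = 2` it amounts to the straight-line code below (an illustration of the unrolling in the context of the kernel body, not the actual generated code):

```mojo
# Unrolled equivalent of the @parameter loop for size = 2:
acc += a_shared[local_i, 0] * b_shared[0, local_j]
acc += a_shared[local_i, 1] * b_shared[1, local_j]
```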
Key aspects:

- Boundary Check: `if global_i < size and global_j < size`
  - Prevents out-of-bounds computation
  - Only valid threads perform work
  - Essential because TPB (3×3) > SIZE (2×2)
- Accumulator Type: `var acc: out.element_type`
  - Uses the output tensor’s element type for type safety
  - Ensures consistent numeric precision
  - Initialized to zero before accumulation
- Loop Optimization: `@parameter for k in range(size)`
  - Unrolls the loop at compile time
  - Enables better instruction scheduling
  - Efficient for small, known matrix sizes
- Result Writing: `out[global_i, global_j] = acc`
  - Protected by the same guard condition
  - Only valid threads write results
  - Maintains matrix bounds safety
Thread Safety and Synchronization:

- Guard Conditions:
  - Input loading: `if global_i < size and global_j < size`
  - Computation: the same guard ensures thread safety
  - Output writing: protected by the same condition
  - Prevents invalid memory access and race conditions
- Memory Access Safety:
  - Shared memory: accessed only within TPB bounds
  - Global memory: protected by size checks
  - Output: guarded writes prevent corruption
Key Language Features:

- LayoutTensor Benefits:
  - Direct 2D indexing simplifies code
  - Type safety through `element_type`
  - Efficient memory layout handling
- Shared Memory Allocation:
  - TensorBuilder for structured allocation
  - Row-major layout matching the input tensors
  - Proper alignment for efficient access
- Synchronization (see the sketch after this list):
  - `barrier()` ensures shared memory consistency
  - Proper synchronization between load and compute
  - Thread cooperation within the block
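To see why the `barrier()` is load-bearing, trace one cross-thread dependency from the kernel (thread coordinates below follow the kernel’s `(local_i, local_j)` convention):

```mojo
# Why barrier() must sit between the load and compute phases:
#
#   thread (0, 1) writes a_shared[0, 1]   # load phase
#   thread (0, 0) reads  a_shared[0, 1]   # compute phase, k = 1 term
#
# Without barrier(), thread (0, 0) could reach its dot product before
# thread (0, 1) finishes the store, and would read stale shared memory.
```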
Performance Optimizations:

- Memory Access Efficiency:
  - Single global memory load per element (quantified below)
  - Multiple reuses through shared memory
  - Coalesced memory access patterns
- Thread Cooperation:
  - Collaborative data loading
  - Shared data reuse
  - Efficient thread synchronization
- Computational Benefits:
  - Reduced global memory traffic
  - Better cache utilization
  - Improved instruction throughput
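A back-of-envelope count of matrix-element reads makes the saving concrete (single block, caches ignored). In the naive version each of the \(\text{SIZE}^2\) threads reads \(2 \cdot \text{SIZE}\) global elements for its dot product; here each thread performs one global load per input tensor, and the dot-product reads hit shared memory instead:

\[
\text{naive global reads} = 2 \cdot \text{SIZE}^3 = 16,
\qquad
\text{shared-memory global reads} = 2 \cdot \text{SIZE}^2 = 8
\]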
This implementation significantly improves performance over the naive version by:
- Reducing global memory accesses
- Enabling data reuse through shared memory
- Using efficient 2D indexing with LayoutTensor
- Maintaining proper thread synchronization