Puzzle 13: Axis Sum
Overview
Implement a kernel that computes the sum of each row of a 2D matrix a and stores the result in out, using LayoutTensor.
Key concepts
In this puzzle, you’ll learn about:
- Parallel reduction along matrix dimensions using LayoutTensor
- Using block coordinates for data partitioning
- Efficient shared memory reduction patterns
- Working with multi-dimensional tensor layouts
The key insight is understanding how to map thread blocks to matrix rows and perform efficient parallel reduction within each block while leveraging LayoutTensor’s dimensional indexing.
Configuration
- Matrix dimensions: \(\text{BATCH} \times \text{SIZE} = 4 \times 6\)
- Threads per block: \(\text{TPB} = 8\)
- Grid dimensions: \(1 \times \text{BATCH}\)
- Shared memory: \(\text{TPB}\) elements per block
- Input layout: Layout.row_major(BATCH, SIZE)
- Output layout: Layout.row_major(BATCH, 1)
Matrix visualization:
Row 0: [0, 1, 2, 3, 4, 5] → Block(0,0)
Row 1: [6, 7, 8, 9, 10, 11] → Block(0,1)
Row 2: [12, 13, 14, 15, 16, 17] → Block(0,2)
Row 3: [18, 19, 20, 21, 22, 23] → Block(0,3)
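Each row r holds the consecutive values \(6r\) through \(6r + 5\), so the expected per-row sums can be checked by hand:

\[
\sum_{c=0}^{5} (6r + c) = 36r + 15
\quad\Rightarrow\quad
15,\; 51,\; 87,\; 123 \quad \text{for } r = 0, 1, 2, 3
\]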
Code to Complete
from gpu import thread_idx, block_idx, block_dim, barrier
from layout import Layout, LayoutTensor
from layout.tensor_builder import LayoutTensorBuild as tb
alias TPB = 8
alias BATCH = 4
alias SIZE = 6
alias BLOCKS_PER_GRID = (1, BATCH)
alias THREADS_PER_BLOCK = (TPB, 1)
alias dtype = DType.float32
alias in_layout = Layout.row_major(BATCH, SIZE)
alias out_layout = Layout.row_major(BATCH, 1)
fn axis_sum[
in_layout: Layout, out_layout: Layout
](
out: LayoutTensor[mut=False, dtype, out_layout],
a: LayoutTensor[mut=False, dtype, in_layout],
size: Int,
):
global_i = block_dim.x * block_idx.x + thread_idx.x
local_i = thread_idx.x
batch = block_idx.y
# FILL ME IN (roughly 15 lines)
View full file: problems/p13/p13.mojo
Tips
- Use batch = block_idx.y to select the row
- Load elements into shared memory: cache[local_i] = a[batch, local_i] (zero-pad when local_i >= size)
- Perform a parallel reduction with a halving stride
- Thread 0 writes the final sum to out[batch, 0]
Running the Code
To test your solution, run the following command in your terminal:
magic run p13
Your output will look like this if the puzzle isn’t solved yet:
out: DeviceBuffer([0.0, 0.0, 0.0, 0.0])
expected: HostBuffer([15.0, 51.0, 87.0, 123.0])
Solution
fn axis_sum[
in_layout: Layout, out_layout: Layout
](
out: LayoutTensor[mut=False, dtype, out_layout],
a: LayoutTensor[mut=False, dtype, in_layout],
size: Int,
):
global_i = block_dim.x * block_idx.x + thread_idx.x
local_i = thread_idx.x
batch = block_idx.y
cache = tb[dtype]().row_major[TPB]().shared().alloc()
# Visualize:
# Block(0,0): [T0,T1,T2,T3,T4,T5,T6,T7] -> Row 0: [0,1,2,3,4,5]
# Block(0,1): [T0,T1,T2,T3,T4,T5,T6,T7] -> Row 1: [6,7,8,9,10,11]
# Block(0,2): [T0,T1,T2,T3,T4,T5,T6,T7] -> Row 2: [12,13,14,15,16,17]
# Block(0,3): [T0,T1,T2,T3,T4,T5,T6,T7] -> Row 3: [18,19,20,21,22,23]
    # each row is handled by its own block because grid_dim=(1, BATCH)
if local_i < size:
cache[local_i] = a[batch, local_i]
else:
        # Zero-initialize padding elements so they contribute nothing to the reduction
cache[local_i] = 0
barrier()
    # parallel reduction of the row sum within this block
stride = TPB // 2
while stride > 0:
if local_i < stride:
cache[local_i] += cache[local_i + stride]
barrier()
stride //= 2
    # thread 0 holds the final sum and writes it out for this batch
if local_i == 0:
out[batch, 0] = cache[0]
The solution implements a parallel row-wise sum reduction for a 2D matrix using LayoutTensor. Here’s a comprehensive breakdown:
Matrix Layout and Block Mapping
Input Matrix (4×6) with LayoutTensor: Block Assignment:
[[ a[0,0] a[0,1] a[0,2] a[0,3] a[0,4] a[0,5] ] → Block(0,0)
[ a[1,0] a[1,1] a[1,2] a[1,3] a[1,4] a[1,5] ] → Block(0,1)
[ a[2,0] a[2,1] a[2,2] a[2,3] a[2,4] a[2,5] ] → Block(0,2)
 [ a[3,0] a[3,1] a[3,2] a[3,3] a[3,4] a[3,5] ]] → Block(0,3)
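The one-block-per-row mapping comes entirely from the launch configuration: grid_dim = (1, BATCH) creates BATCH blocks along y, and each block's block_idx.y picks its row. Below is a minimal host-side sketch of the launch, assuming the DeviceContext harness and the aliases defined earlier on this page; the complete, runnable version lives in problems/p13/p13.mojo, and the input-fill step is omitted here.

from gpu.host import DeviceContext

def main():
    with DeviceContext() as ctx:
        # One device buffer per tensor; the real p13.mojo also fills `a` with 0..23.
        out_buf = ctx.enqueue_create_buffer[dtype](BATCH).enqueue_fill(0)
        a_buf = ctx.enqueue_create_buffer[dtype](BATCH * SIZE).enqueue_fill(0)

        # Wrap the raw buffers so the kernel can use 2D LayoutTensor indexing.
        out_tensor = LayoutTensor[mut=False, dtype, out_layout](out_buf.unsafe_ptr())
        a_tensor = LayoutTensor[mut=False, dtype, in_layout](a_buf.unsafe_ptr())

        # grid_dim=(1, BATCH): one block per row; block_dim=(TPB, 1): TPB threads per row.
        ctx.enqueue_function[axis_sum[in_layout, out_layout]](
            out_tensor,
            a_tensor,
            SIZE,
            grid_dim=BLOCKS_PER_GRID,
            block_dim=THREADS_PER_BLOCK,
        )
        ctx.synchronize()

Because each row fits within a single block, no inter-block communication or second reduction pass is needed.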
Parallel Reduction Process
- Initial Data Loading:
  Block(0,0): cache = [a[0,0] a[0,1] a[0,2] a[0,3] a[0,4] a[0,5] * *] // * = zero-initialized padding
  Block(0,1): cache = [a[1,0] a[1,1] a[1,2] a[1,3] a[1,4] a[1,5] * *]
  Block(0,2): cache = [a[2,0] a[2,1] a[2,2] a[2,3] a[2,4] a[2,5] * *]
  Block(0,3): cache = [a[3,0] a[3,1] a[3,2] a[3,3] a[3,4] a[3,5] * *]
- Reduction Steps (for Block 0,0):
  Initial:  [0  1  2  3  4  5  0  0]   // padding zero-initialized
  Stride 4: [4  6  2  3  4  5  0  0]
  Stride 2: [6  9  2  3  4  5  0  0]
  Stride 1: [15 9  2  3  4  5  0  0]
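The same arithmetic can be replayed on the CPU to sanity-check the trace. A small standalone sketch (illustrative only, not part of the puzzle files) that emulates the in-block reduction for row 0:

from collections import List

fn main():
    # Row 0 of the matrix plus two zero-initialized padding slots (TPB = 8 > SIZE = 6).
    var cache = List[Float32](0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 0.0, 0.0)
    var stride = len(cache) // 2
    while stride > 0:
        # On the GPU these additions happen in parallel, one per thread with local_i < stride.
        for i in range(stride):
            cache[i] += cache[i + stride]
        stride //= 2
    print(cache[0])  # 15.0, the expected sum for row 0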
Key Implementation Features:
- Layout Configuration:
  - Input: row-major layout (BATCH × SIZE)
  - Output: row-major layout (BATCH × 1)
  - Each block processes one complete row (see the row-major indexing sketch after this list)
- Memory Access Pattern:
  - LayoutTensor 2D indexing for input: a[batch, local_i]
  - Shared memory for efficient reduction
  - LayoutTensor 2D indexing for output: out[batch, 0]
- Parallel Reduction Logic:
  stride = TPB // 2
  while stride > 0:
      if local_i < stride:
          cache[local_i] += cache[local_i + stride]
      barrier()
      stride //= 2
- Output Writing:
  if local_i == 0:
      out[batch, 0] = cache[0]  # one result per batch
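For reference, the row-major layouts reduce 2D indices to flat memory offsets, which is why adjacent threads in a block load adjacent, coalesced elements of their row:

\[
\text{offset}_{a}(r, c) = r \cdot \text{SIZE} + c,
\qquad
\text{offset}_{\text{out}}(r, 0) = r \cdot 1 + 0 = r
\]

For example, a[2, 3] sits at offset \(2 \cdot 6 + 3 = 15\); since the input holds the values 0–23 in row-major order, that element's value is also 15 (row 2, column 3 in the visualization above).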
Performance Optimizations:
- Memory Efficiency:
  - Coalesced memory access through LayoutTensor
  - Shared memory for fast reduction
  - Single write per row result
- Thread Utilization:
  - Even load balancing: every block does identical work on its row
  - Little thread divergence (only the load guard, the shrinking reduction guard, and the final write)
  - Efficient parallel reduction pattern
- Synchronization:
  - Minimal barriers (one after the shared-memory load, one per reduction step)
  - Independent processing between rows
  - No inter-block communication needed
Complexity Analysis:
- Time: \(O(\log n)\) per row, where \(n\) is the row length
- Space: \(O(\text{TPB})\) shared memory per block
- Total parallel time: \(O(\log n)\) with sufficient threads
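Concretely, for the configuration used here:

\[
\text{reduction steps} = \log_2 \text{TPB} = \log_2 8 = 3 \quad (\text{strides } 4, 2, 1),
\qquad
\text{shared memory} = \text{TPB} \times 4\ \text{bytes} = 32\ \text{bytes per block}
\]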