Key concepts
In this puzzle, you'll learn about:
- Implementing parallel reduction with LayoutTensor, similar to puzzles 8 and 9
- Managing shared memory using the tensor builder (LayoutTensorBuild)
- Coordinating threads for collective operations
- Using layout-aware tensor operations
The key insight is how LayoutTensor simplifies memory management while maintaining efficient parallel reduction patterns.
Configuration
- Vector size: SIZE = 8 elements
- Threads per block: TPB = 8
- Number of blocks: 1
- Output size: 1 element
- Shared memory: TPB elements
Notes:
- Tensor builder: Use tb[dtype]().row_major[TPB]().shared().alloc() (tb is the alias for LayoutTensorBuild imported in the code; see the sketch after these notes)
- Element access: Natural indexing with bounds checking
- Layout handling: Separate layouts for input and output
- Thread coordination: Same synchronization patterns with barrier()
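To make these notes concrete, here is a minimal sketch of the shared-memory setup inside the kernel, using the same tb alias and the same variable names as the solution further down:
# Allocate TPB elements of shared memory via the tensor builder
shared = tb[dtype]().row_major[TPB]().shared().alloc()
global_i = block_dim.x * block_idx.x + thread_idx.x  # index into the input tensors
local_i = thread_idx.x                               # index into shared memory
if global_i < size:
    # Natural, layout-aware indexing on both the inputs and the shared tile
    shared[local_i] = a[global_i] * b[global_i]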
Code to complete
from gpu import thread_idx, block_idx, block_dim, barrier
from layout import Layout, LayoutTensor
from layout.tensor_builder import LayoutTensorBuild as tb
alias TPB = 8
alias SIZE = 8
alias BLOCKS_PER_GRID = (1, 1)
alias THREADS_PER_BLOCK = (SIZE, 1)
alias dtype = DType.float32
alias layout = Layout.row_major(SIZE)
alias out_layout = Layout.row_major(1)
fn dot_product[
in_layout: Layout, out_layout: Layout
](
out: LayoutTensor[mut=True, dtype, out_layout],
a: LayoutTensor[mut=True, dtype, in_layout],
b: LayoutTensor[mut=True, dtype, in_layout],
size: Int,
):
# FILL ME IN (roughly 13 lines)
...
View full file: problems/p10/p10_layout_tensor.mojo
Tips
- Create shared memory with tensor builder
- Store a[global_i] * b[global_i] in shared[local_i]
- Use parallel reduction pattern with barrier()
- Let thread 0 write final result to out[0]
Running the code
To test your solution, run the following command in your terminal:
magic run p10_layout_tensor
Your output will look like this if the puzzle isn't solved yet:
out: HostBuffer([0.0])
expected: HostBuffer([140.0])
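The expected value follows from the test inputs, where both vectors hold 0 through 7 (this is also what the reduction trace in the solution below starts from):
\(\sum_{i=0}^{7} i \cdot i = 0 + 1 + 4 + 9 + 16 + 25 + 36 + 49 = 140\)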
Solution
fn dot_product[
in_layout: Layout, out_layout: Layout
](
out: LayoutTensor[mut=True, dtype, out_layout],
a: LayoutTensor[mut=True, dtype, in_layout],
b: LayoutTensor[mut=True, dtype, in_layout],
size: Int,
):
shared = tb[dtype]().row_major[TPB]().shared().alloc()
global_i = block_dim.x * block_idx.x + thread_idx.x
local_i = thread_idx.x
# Compute element-wise multiplication into shared memory
if global_i < size:
shared[local_i] = a[global_i] * b[global_i]
# Synchronize threads within block
barrier()
# Parallel reduction in shared memory
stride = TPB // 2
while stride > 0:
if local_i < stride:
shared[local_i] += shared[local_i + stride]
barrier()
stride //= 2
# Only thread 0 writes the final result
if local_i == 0:
out[0] = shared[0]
The solution implements a parallel reduction for dot product using LayoutTensor. Here's the detailed breakdown:
Phase 1: Element-wise Multiplication
Each thread performs one multiplication with natural indexing:
shared[local_i] = a[global_i] * b[global_i]
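In the complete kernel this store sits behind the bounds check from the solution above, so threads whose global index falls past the end of the vector never touch shared memory:
if global_i < size:
    # Each in-range thread writes exactly one product into its shared slot
    shared[local_i] = a[global_i] * b[global_i]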
Phase 2: Parallel Reduction
Tree-based reduction with layout-aware operations:
Initial: [0*0 1*1 2*2 3*3 4*4 5*5 6*6 7*7]
= [0 1 4 9 16 25 36 49]
Step 1: [0+16 1+25 4+36 9+49 16 25 36 49]
= [16 26 40 58 16 25 36 49]
Step 2: [16+40 26+58 40 58 16 25 36 49]
= [56 84 40 58 16 25 36 49]
Step 3: [56+84 84 40 58 16 25 36 49]
= [140 84 40 58 16 25 36 49]
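The trace above is produced by the reduction loop from the solution; the comments map each iteration to a step of the trace:
stride = TPB // 2                 # stride = 4 -> Step 1
while stride > 0:
    if local_i < stride:
        shared[local_i] += shared[local_i + stride]
    barrier()                     # make all partial sums visible before shrinking the stride
    stride //= 2                  # 4 -> 2 (Step 2) -> 1 (Step 3) -> 0 (loop exits)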
Key Implementation Features:
- Memory Management:
  - Clean shared memory allocation with tensor builder
  - Type-safe operations with LayoutTensor
  - Automatic bounds checking
  - Layout-aware indexing
- Thread Synchronization:
  - barrier() after initial multiplication
  - barrier() between reduction steps
  - Safe thread coordination
- Reduction Logic:
  stride = TPB // 2
  while stride > 0:
      if local_i < stride:
          shared[local_i] += shared[local_i + stride]
      barrier()
      stride //= 2
- Performance Benefits:
  - \(O(\log n)\) time complexity (step count worked out below)
  - Coalesced memory access
  - Minimal thread divergence
  - Efficient shared memory usage
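Concretely, with TPB = 8 the reduction loop runs \(\log_2 8 = 3\) iterations (strides 4, 2, 1), so the block finishes in 3 synchronized steps instead of the 7 sequential additions a single-threaded sum would need.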
The LayoutTensor version maintains the same efficient parallel reduction while providing:
- Better type safety
- Cleaner memory management
- Layout awareness
- Natural indexing syntax
Barrier Synchronization Importance
The barrier() between reduction steps is critical for correctness. Here's why:
Without barrier(), race conditions occur:
Initial shared memory: [0 1 4 9 16 25 36 49]
Step 1 (stride = 4):
Thread 0 reads: shared[0] = 0, shared[4] = 16
Thread 1 reads: shared[1] = 1, shared[5] = 25
Thread 2 reads: shared[2] = 4, shared[6] = 36
Thread 3 reads: shared[3] = 9, shared[7] = 49
Without barrier:
- Thread 0 writes: shared[0] = 0 + 16 = 16
- Thread 1 starts next step (stride = 2) before Thread 0 finishes
and reads old value shared[0] = 0 instead of 16!
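For comparison, here is a sketch of the broken variant implied above: it is the same reduction loop as in the solution, only with the barrier() inside the loop removed, which permits exactly this race:
# BROKEN - for illustration only: no barrier() between reduction steps
stride = TPB // 2
while stride > 0:
    if local_i < stride:
        shared[local_i] += shared[local_i + stride]
    # missing barrier(): a thread may move to the next stride and read a
    # neighbor's slot before that neighbor has written its partial sum
    stride //= 2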
With barrier():
Step 1 (stride = 4):
All threads write their sums:
[16 26 40 58 16 25 36 49]
barrier() ensures ALL threads see these values
Step 2 (stride = 2):
Now threads safely read the updated values:
Thread 0: shared[0] = 16 + 40 = 56
Thread 1: shared[1] = 26 + 58 = 84
The barrier() ensures:
- All writes from current step complete
- All threads see updated values
- No thread starts next iteration early
- Consistent shared memory state
Without these synchronization points, we could get:
- Memory race conditions
- Threads reading stale values
- Non-deterministic results
- Incorrect final sum