Complete Version

Configuration

- Array size: `SIZE_2 = 15` elements
- Threads per block: `TPB = 8`
- Number of blocks: 2
- Shared memory: `TPB` elements per block

Notes:

- Block handling: multiple blocks each process one segment of the array
- Partial blocks: the last block may not be full
- Block sums: running totals are stored between blocks
- Global result: local prefix sums and block sums are combined
- Layout safety: consistent layout handling through `LayoutTensor`
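The block arithmetic behind this configuration can be checked with a few lines of host-side Python (used here purely as an illustration; the constants mirror `SIZE_2` and `TPB` above):

```python
SIZE_2 = 15  # total number of elements
TPB = 8      # threads per block

# Number of blocks needed to cover the array (ceiling division).
num_blocks = (SIZE_2 + TPB - 1) // TPB

# The last block is partial: it covers only the remaining elements.
last_block_size = SIZE_2 - (num_blocks - 1) * TPB

print(num_blocks)       # 2
print(last_block_size)  # 7 of the 8 threads have work
```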
Code to complete

```mojo
alias SIZE_2 = 15
alias BLOCKS_PER_GRID_2 = (2, 1)
alias THREADS_PER_BLOCK_2 = (TPB, 1)


fn prefix_sum[
    layout: Layout
](
    out: LayoutTensor[mut=True, dtype, layout],
    a: LayoutTensor[mut=False, dtype, layout],
    size: Int,
):
    global_i = block_dim.x * block_idx.x + thread_idx.x
    local_i = thread_idx.x
    # FILL ME IN (roughly 19 lines)
```
View full file: problems/p12/p12.mojo
Tips

- Compute local prefix sums like in the Simple Version
- The last thread of each block stores the block sum at `TPB * (block_idx.x + 1)`
- Add the previous block's sum to every element of the current block
- Guard all operations with array-bounds checks
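These steps can be prototyped on the CPU before writing the kernel. The sketch below (plain Python; `local_scan` is an illustrative name) runs the doubling-offset local scan that each block performs, with a list snapshot standing in for the barrier between passes:

```python
import math

TPB = 8
data = list(range(15))  # SIZE_2 = 15 → blocks of 8 and 7 elements

def local_scan(block):
    # Doubling-offset inclusive scan, mirroring the shared-memory loop:
    # log2(TPB) passes with offsets 1, 2, 4.
    shared = list(block)
    offset = 1
    for _ in range(int(math.log2(TPB))):
        prev = list(shared)  # snapshot plays the role of the barrier
        for i in range(len(shared)):
            if i >= offset:
                shared[i] = prev[i] + prev[i - offset]
        offset *= 2
    return shared

blocks = [data[i:i + TPB] for i in range(0, len(data), TPB)]
scanned = [local_scan(b) for b in blocks]
print(scanned[0])  # [0, 1, 3, 6, 10, 15, 21, 28]
print(scanned[1])  # [8, 17, 27, 38, 50, 63, 77]
```

Note that the partial last block (7 elements) goes through the same three passes; the bounds check `i >= offset` plus the shorter list length keep it correct.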
Running the code
To test your solution, run the following command in your terminal:

```bash
magic run p12 --complete
```

Your output will look like this if the puzzle isn’t solved yet:

```txt
out: DeviceBuffer([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
expected: HostBuffer([0.0, 1.0, 3.0, 6.0, 10.0, 15.0, 21.0, 28.0])
```
Solution
```mojo
fn prefix_sum[
    layout: Layout
](
    out: LayoutTensor[mut=True, dtype, layout],
    a: LayoutTensor[mut=False, dtype, layout],
    size: Int,
):
    global_i = block_dim.x * block_idx.x + thread_idx.x
    local_i = thread_idx.x
    shared = tb[dtype]().row_major[TPB]().shared().alloc()
    if global_i < size:
        shared[local_i] = a[global_i]
    barrier()

    # Idea: two passes
    # SIZE=15, TPB=8, BLOCKS=2
    # buffer: [0,1,2,...,7 | 8,...,14]
    #
    # Step 1: each block computes its local prefix sum
    #   Block 0: [0,1,2,3,4,5,6,7]    → [0,1,3,6,10,15,21,28]
    #   Block 1: [8,9,10,11,12,13,14] → [8,17,27,38,50,63,77]
    #
    # Step 2: store block sums
    #   Block 0's sum (28) → position 8
    #
    # Step 3: add the previous block's sum
    #   Block 1: each element += 28
    #   [8,17,27,38,50,63,77] → [36,45,55,66,78,91,105]
    #
    # Final result combines both blocks:
    # [0,1,3,6,10,15,21,28, 36,45,55,66,78,91,105]

    # local prefix sum for each block
    offset = 1
    for i in range(Int(log2(Scalar[dtype](TPB)))):
        if local_i >= offset and local_i < size:
            shared[local_i] += shared[local_i - offset]
        barrier()
        offset *= 2

    # store block results
    if global_i < size:
        out[global_i] = shared[local_i]

    # store block sum in the first element of the next block:
    # - only the last thread (local_i == TPB - 1) of each block except the
    #   last block executes
    # - Block 0: thread 7 stores 28 (sum of elements 0-7) at position
    #   TPB * (block_idx.x + 1) = 8 * (0 + 1) = 8 (start of Block 1)
    # - Block 1: no action (last block)
    # Memory state afterwards:
    # [0,1,3,6,10,15,21,28 | 28,17,27,38,50,63,77]
    #                         ↑
    #                         Block 0's sum stored here
    if local_i == TPB - 1 and block_idx.x < size // TPB - 1:
        out[TPB * (block_idx.x + 1)] = shared[local_i]

    # synchronize the threads of this block before the second pass
    # (barrier() does not synchronize across blocks)
    barrier()

    # second pass: add the previous block's sum, so that for Block 1:
    # Before: [8,17,27,38,50,63,77]
    # Add 28: [36,45,55,66,78,91,105]
    if block_idx.x > 0 and global_i < size:
        shared[local_i] += out[block_idx.x * TPB - 1]

    # final result
    if global_i < size:
        out[global_i] = shared[local_i]
```
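As a sanity check, the two-pass scheme can be simulated on the host and compared with a serial inclusive scan. A Python sketch under the same `SIZE_2`/`TPB` configuration (variable names are illustrative):

```python
from itertools import accumulate

SIZE_2, TPB = 15, 8
a = [float(i) for i in range(SIZE_2)]

# Pass 1: local inclusive scan within each TPB-sized block.
out = []
for start in range(0, SIZE_2, TPB):
    out.extend(accumulate(a[start:start + TPB]))

# Pass 2: each block after the first adds the total of everything before
# it; after pass 1, out[start - 1] holds the previous block's running sum.
for start in range(TPB, SIZE_2, TPB):
    carry = out[start - 1]
    for i in range(start, min(start + TPB, SIZE_2)):
        out[i] += carry

assert out == list(accumulate(a))  # matches a serial inclusive scan
print(out)  # ends with 105.0, the sum of 0..14
```

The sequential host loop stands in for the cross-block ordering; in the two-block case above, Block 1 simply reads Block 0's total before adjusting its own elements.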
This solution handles the multi-block prefix sum in three main phases:

- Local prefix sum (per block):

  - Block 0 (8 elements): [0,1,2,3,4,5,6,7] → after local prefix sum: [0,1,3,6,10,15,21,28]
  - Block 1 (7 elements): [8,9,10,11,12,13,14] → after local prefix sum: [8,17,27,38,50,63,77]

- Block sum communication:

  - The last thread (`local_i == TPB - 1`) of each non-final block stores its block's sum at the next block's start:

    ```mojo
    if local_i == TPB - 1 and block_idx.x < size // TPB - 1:
        out[TPB * (block_idx.x + 1)] = shared[local_i]
    ```

  - Block 0's sum (28) is stored at position 8
  - Memory layout: [0,1,3,6,10,15,21,28 | 28,17,27,38,50,63,77], with position 8 holding Block 0's sum

- Final adjustment:

  - Each block after the first adds the previous block's sum:

    ```mojo
    if block_idx.x > 0 and global_i < size:
        shared[local_i] += out[block_idx.x * TPB - 1]
    ```

  - Block 1: each element += 28
  - Final result: [0,1,3,6,10,15,21,28, 36,45,55,66,78,91,105]
Key implementation details:

- Uses `barrier()` after shared memory operations
- Handles partial blocks (last block size < `TPB`)
- Guards all operations with proper bounds checking
- Maintains correct thread and block synchronization
- Achieves \(O(\log n)\) complexity per block
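The per-block cost is just the doubling loop: for `TPB = 8`, `Int(log2(TPB))` gives 3 passes with offsets 1, 2, 4. A quick Python check of that arithmetic:

```python
import math

TPB = 8
offsets = []
offset = 1
for _ in range(int(math.log2(TPB))):  # same trip count as the kernel's loop
    offsets.append(offset)
    offset *= 2
print(offsets)  # [1, 2, 4] → 3 passes for 8 threads per block
```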
The solution scales to arbitrary-sized inputs by combining local prefix sums with efficient block-to-block communication.