Complete Version

Configuration

  • Array size: SIZE_2 = 15 elements
  • Threads per block: TPB = 8
  • Number of blocks: 2
  • Shared memory: TPB elements per block

Notes:

  • Block handling: Multiple blocks process array segments
  • Partial blocks: Last block may not be full
  • Block sums: Store running totals between blocks
  • Global result: Combine local and block sums
  • Layout safety: Consistent layout handling through LayoutTensor
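
The block-handling and partial-block notes above can be checked with a little index arithmetic. The following is a minimal CPU-side sketch in Python, not Mojo GPU code; the helper name `active_threads` is hypothetical and only mirrors the kernel's `global_i < size` guard.

```python
SIZE_2 = 15   # array length from the configuration above
TPB = 8       # threads per block
BLOCKS = 2    # number of blocks


def active_threads(block_idx, size=SIZE_2, tpb=TPB):
    """Global indices of the threads in one block that pass `global_i < size`."""
    return [tpb * block_idx + t for t in range(tpb) if tpb * block_idx + t < size]


segments = [active_threads(b) for b in range(BLOCKS)]
for b, seg in enumerate(segments):
    print(f"block {b}: {len(seg)} active threads, global indices {seg[0]}..{seg[-1]}")
```

Block 0 is full (8 active threads), while Block 1 has only 7, which is why every read and write in the kernel needs a bounds guard.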

Code to complete

alias SIZE_2 = 15
alias BLOCKS_PER_GRID_2 = (2, 1)
alias THREADS_PER_BLOCK_2 = (TPB, 1)


fn prefix_sum[
    layout: Layout
](
    out: LayoutTensor[mut=True, dtype, layout],
    a: LayoutTensor[mut=False, dtype, layout],
    size: Int,
):
    global_i = block_dim.x * block_idx.x + thread_idx.x
    local_i = thread_idx.x
    # FILL ME IN (roughly 19 lines)


View full file: problems/p12/p12.mojo

Tips
  1. Compute local prefix sums like in Simple Version
  2. Last thread stores block sum at TPB * (block_idx.x + 1)
  3. Add previous block’s sum to current block
  4. Handle array bounds for all operations
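
As an illustration of tip 1, here is a CPU-side Python sketch of a doubling scan over a single block. It is a simulation under stated assumptions rather than the kernel itself: the name `block_inclusive_scan` is made up, each pass reads from a snapshot instead of relying on `barrier()`, and the loop stops once the offset reaches the block length instead of always running log2(TPB) passes.

```python
def block_inclusive_scan(values):
    """Inclusive prefix sum of one block using offsets 1, 2, 4, ...

    Each pass reads from a snapshot (`prev`) so that no element is read
    after it has already been updated in the same pass.
    """
    shared = list(values)
    offset = 1
    steps = 0
    while offset < len(shared):
        prev = list(shared)  # stands in for "all threads read, then all write"
        for i in range(offset, len(shared)):
            shared[i] = prev[i] + prev[i - offset]
        offset *= 2
        steps += 1
    return shared, steps


block0, steps0 = block_inclusive_scan(range(8))      # full block
block1, steps1 = block_inclusive_scan(range(8, 15))  # partial block (7 elements)
print(block0, steps0)
print(block1, steps1)
```

Both blocks finish in three passes (offsets 1, 2 and 4), and the last element of each result is that block's total.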

Running the code

To test your solution, run the following command in your terminal:

magic run p12 --complete

Your output will look like this if the puzzle isn’t solved yet:

out: DeviceBuffer([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
expected: HostBuffer([0.0, 1.0, 3.0, 6.0, 10.0, 15.0, 21.0, 28.0, 36.0, 45.0, 55.0, 66.0, 78.0, 91.0, 105.0])

Solution

fn prefix_sum[
    layout: Layout
](
    out: LayoutTensor[mut=True, dtype, layout],
    a: LayoutTensor[mut=False, dtype, layout],
    size: Int,
):
    global_i = block_dim.x * block_idx.x + thread_idx.x
    local_i = thread_idx.x
    shared = tb[dtype]().row_major[TPB]().shared().alloc()
    if global_i < size:
        shared[local_i] = a[global_i]

    barrier()

    # Idea: two passes
    # SIZE=15, TPB=8, BLOCKS=2
    # buffer: [0,1,2,...,7 | 8,...,14]

    # Step 1: Each block computes local prefix sum
    # Block 0: [0,1,2,3,4,5,6,7] → [0,1,3,6,10,15,21,28]
    # Block 1: [8,9,10,11,12,13,14] → [8,17,27,38,50,63,77]

    # Step 2: Store block sums
    # Block 0's sum (28) → position 8

    # Step 3: Add previous block's sum
    # Block 1: Each element += 28
    # [8,17,27,38,50,63,77] → [36,45,55,66,78,91,105]

    # Final result combines both blocks:
    # [0,1,3,6,10,15,21,28, 36,45,55,66,78,91,105]

    # local prefix-sum for each block
    offset = 1
    for i in range(Int(log2(Scalar[dtype](TPB)))):
        if local_i >= offset and local_i < size:
            shared[local_i] += shared[local_i - offset]

        barrier()
        offset *= 2

    # store block results
    if global_i < size:
        out[global_i] = shared[local_i]

    # store block sum in first element of next block:
    # - Intended for the last thread (local_i == 7) of every block except the last
    # - Target index: TPB * (block_idx.x + 1)
    #   Block 0: 8 * (0 + 1) = position 8 (start of Block 1)
    #   Block 1: no action (last block)
    # - Intended memory state:
    #   [0,1,3,6,10,15,21,28 | 28,17,27,38,50,63,77]
    #                          ↑
    #                         Block 0's sum stored here
    # Note: size // TPB rounds down, so with size = 15 the guard below is
    # block_idx.x < 0 and never fires. The second pass still works because it
    # reads Block 0's total from out[7], not from position 8.
    if local_i == TPB - 1 and block_idx.x < size // TPB - 1:
        out[TPB * (block_idx.x + 1)] = shared[local_i]

    # barrier() only synchronizes the threads of this block; it does not
    # wait for other blocks to store their sums
    barrier()

    # second pass: add previous block's sum to Block 1's local prefix sums:
    # Before: [8,17,27,38,50,63,77]
    # Add 28: [36,45,55,66,78,91,105]
    if block_idx.x > 0 and global_i < size:
        shared[local_i] += out[block_idx.x * TPB - 1]

    # final result
    if global_i < size:
        out[global_i] = shared[local_i]


This solution handles multi-block prefix sum in three main phases:
  1. Local prefix sum (per block):

    Block 0 (8 elements):    [0,1,2,3,4,5,6,7]
    After local prefix sum:  [0,1,3,6,10,15,21,28]
    
    Block 1 (7 elements):    [8,9,10,11,12,13,14]
    After local prefix sum:  [8,17,27,38,50,63,77]
    
  2. Block sum communication:

    • The last thread (local_i == TPB-1) in each non-final block stores its block’s sum at the next block’s start:
    if local_i == TPB - 1 and block_idx.x < size // TPB - 1:
        out[TPB * (block_idx.x + 1)] = shared[local_i]
    
    • Block 0’s sum (28) stored at position 8
    • Memory layout:
      [0,1,3,6,10,15,21,28 | 28,17,27,38,50,63,77]
                             ↑ Block 0’s sum
  3. Final adjustment:

    • Each block after first adds previous block’s sum
    if block_idx.x > 0 and global_i < size:
        shared[local_i] += out[block_idx.x * TPB - 1]
    
    • Block 1: Each element += 28
    • Final result: [0,1,3,6,10,15,21,28, 36,45,55,66,78,91,105]
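
The phases above can be traced end to end on the CPU for the two-block case. This is a hedged Python sketch, not the kernel: `itertools.accumulate` stands in for each block's local scan, and the variable names (`local`, `prev_total`) are illustrative. Like the kernel's final adjustment, it reads the previous block's total from index `block_idx * TPB - 1`.

```python
from itertools import accumulate

TPB = 8
a = list(range(15))  # input [0, 1, ..., 14]

# Phase 1: every block computes an inclusive scan of its own segment.
out = []
for start in range(0, len(a), TPB):
    out.extend(accumulate(a[start:start + TPB]))
local = list(out)  # snapshot of the per-block results

# Final adjustment for Block 1: add the total of Block 0, which is the
# last element of Block 0's local scan (index block_idx * TPB - 1 = 7).
block_idx = 1
prev_total = out[block_idx * TPB - 1]
for g in range(block_idx * TPB, len(a)):
    out[g] += prev_total

print(local)
print(prev_total)
print(out)
```

The final `out` matches a plain running sum over the whole array, which is the property the kernel's output is checked against.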

Key implementation details:

  • Uses barrier() after shared memory operations
  • Handles partial blocks (last block size < TPB)
  • Guards all operations with proper bounds checking
  • Synchronizes threads within each block with barrier(); barrier() does not synchronize across blocks
  • Runs each block’s local scan in \(\log_2(\text{TPB})\) parallel steps (3 for TPB = 8)

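The approach extends to larger inputs by combining local prefix sums with block-to-block communication. With more than two blocks, each block needs the combined total of all preceding blocks, not just its immediate neighbor’s local total.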
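
One way the same idea might extend beyond two blocks is sketched below in CPU-side Python. It is an assumption about how a generalization could look, not the puzzle's kernel: the function `multi_block_scan` and the `carries` list are hypothetical, and the carry for each block is computed as the sum of all earlier block totals.

```python
from itertools import accumulate


def multi_block_scan(a, tpb):
    """Inclusive prefix sum built from per-block scans plus per-block carries."""
    # Local inclusive scan of every block (the last block may be partial).
    local = [list(accumulate(a[s:s + tpb])) for s in range(0, len(a), tpb)]
    # Total of each block is the last element of its local scan.
    totals = [blk[-1] for blk in local]
    # Carry into block k is the sum of the totals of blocks 0..k-1.
    carries = [0] + list(accumulate(totals))[:-1]
    return [x + c for blk, c in zip(local, carries) for x in blk]


a = list(range(20))            # three blocks when tpb = 8: 8 + 8 + 4 elements
result = multi_block_scan(a, tpb=8)
print(result)
```

With three blocks the carries are 0, 28 and 120, so the third block needs the totals of both earlier blocks; adding only the second block's local total (92) would give the wrong result.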