Simple Version

Configuration
- Array size: `SIZE = 8` elements
- Threads per block: `TPB = 8`
- Number of blocks: 1
- Shared memory: `TPB` elements
Notes:
- Data loading: Each thread loads one element using LayoutTensor access
- Memory pattern: Shared memory for intermediate results using `LayoutTensorBuild`
- Thread sync: Coordination between computation phases
- Access pattern: Stride-based parallel computation
- Type safety: Leveraging LayoutTensor's type system
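
Before parallelizing anything, it helps to pin down the target semantics. The following is a minimal sequential CPU sketch of the inclusive prefix sum (a hypothetical standalone program, not part of `p12.mojo`); it produces exactly the `expected` buffer shown later:

```mojo
from collections import List


fn prefix_sum_reference(a: List[Float32]) -> List[Float32]:
    # Inclusive scan: sums[i] = a[0] + a[1] + ... + a[i]
    var sums = List[Float32]()
    var running = Float32(0)
    for i in range(len(a)):
        running += a[i]
        sums.append(running)
    return sums^


def main():
    var a = List[Float32]()
    for i in range(8):
        a.append(Float32(i))
    var sums = prefix_sum_reference(a)
    for i in range(len(sums)):
        print(sums[i], end=" ")  # 0.0 1.0 3.0 6.0 10.0 15.0 21.0 28.0
    print()
```

The sequential version takes 8 dependent additions; the parallel version below does the same work in 3 synchronized steps.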
Code to complete
```mojo
from gpu import thread_idx, block_idx, block_dim, barrier
from layout import Layout, LayoutTensor
from layout.tensor_builder import LayoutTensorBuild as tb
from math import log2

alias TPB = 8
alias SIZE = 8
alias BLOCKS_PER_GRID = (1, 1)
alias THREADS_PER_BLOCK = (TPB, 1)
alias dtype = DType.float32
alias layout = Layout.row_major(SIZE)


fn prefix_sum_simple[
    layout: Layout
](
    out: LayoutTensor[mut=False, dtype, layout],
    a: LayoutTensor[mut=False, dtype, layout],
    size: Int,
):
    global_i = block_dim.x * block_idx.x + thread_idx.x
    local_i = thread_idx.x
    # FILL ME IN (roughly 12 lines)
```
View full file: problems/p12/p12.mojo
Tips
- Load data into `shared[local_i]`
- Use `offset = 1` and double it each step
- Add elements where `local_i >= offset`
- Call `barrier()` between steps
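
Since the offset doubles on every step, the loop runs \(\log_2\) of the block size times:

\[
\log_2(\mathrm{TPB}) = \log_2 8 = 3 \quad \text{steps, with offsets } 1, 2, 4.
\]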
Running the code
To test your solution, run the following command in your terminal:

```bash
magic run p12 --simple
```

Your output will look like this if the puzzle isn't solved yet:

```txt
out: DeviceBuffer([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
expected: HostBuffer([0.0, 1.0, 3.0, 6.0, 10.0, 15.0, 21.0, 28.0])
```
Solution

```mojo
fn prefix_sum_simple[
    layout: Layout
](
    out: LayoutTensor[mut=False, dtype, layout],
    a: LayoutTensor[mut=False, dtype, layout],
    size: Int,
):
    global_i = block_dim.x * block_idx.x + thread_idx.x
    local_i = thread_idx.x
    # Shared memory buffer holding the running sums for this block
    shared = tb[dtype]().row_major[TPB]().shared().alloc()

    # Phase 1: each thread loads one element into shared memory
    if global_i < size:
        shared[local_i] = a[global_i]
    barrier()

    # Phase 2: log2(TPB) = 3 offset-doubling steps (offsets 1, 2, 4)
    offset = 1
    for i in range(Int(log2(Scalar[dtype](TPB)))):
        if local_i >= offset and local_i < size:
            shared[local_i] += shared[local_i - offset]
        barrier()
        offset *= 2

    # Phase 3: each thread writes its accumulated sum to global memory
    if global_i < size:
        out[global_i] = shared[local_i]
```
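
To see what the three barrier-separated steps compute, here is a small CPU sketch (hypothetical, not part of `p12.mojo`) that simulates the same offset-doubling scheme. The explicit copy mimics each step reading the values left by the previous step, which is what the walkthrough below traces:

```mojo
from collections import List


def main():
    alias SIZE = 8

    # Same input as the puzzle: [0, 1, ..., 7]
    var data = List[Float32]()
    for i in range(SIZE):
        data.append(Float32(i))

    # Simulate the barrier-separated steps of the kernel. The copy
    # `prev` stands in for the pre-step values that each thread reads.
    var offset = 1
    while offset < SIZE:
        var prev = data  # snapshot of the "before" values
        for i in range(SIZE):
            if i >= offset:
                data[i] = prev[i] + prev[i - offset]
        offset *= 2

    for i in range(SIZE):
        print(data[i], end=" ")  # 0.0 1.0 3.0 6.0 10.0 15.0 21.0 28.0
    print()
```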
The parallel (inclusive) prefix-sum algorithm works as follows:

Setup & Configuration
- `TPB` (Threads Per Block) = 8
- `SIZE` (Array Size) = 8

Thread Mapping
- `thread_idx.x`: \([0, 1, 2, 3, 4, 5, 6, 7]\) (`local_i`)
- `block_idx.x`: \([0, 0, 0, 0, 0, 0, 0, 0]\)
- `global_i`: \([0, 1, 2, 3, 4, 5, 6, 7]\) (`block_idx.x * TPB + thread_idx.x`)
Initial Load to Shared Memory

```txt
Threads:      T₀   T₁   T₂   T₃   T₄   T₅   T₆   T₇
Input array:  [0    1    2    3    4    5    6    7]
shared:       [0    1    2    3    4    5    6    7]
               ↑    ↑    ↑    ↑    ↑    ↑    ↑    ↑
              T₀   T₁   T₂   T₃   T₄   T₅   T₆   T₇
```
Offset = 1: First Parallel Step
Active threads: \(T_1 \ldots T_7\) (where `local_i ≥ 1`)

```txt
Before:   [0    1    2    3    4    5    6    7]
Add:           +0   +1   +2   +3   +4   +5   +6
                |    |    |    |    |    |    |
Result:   [0    1    3    5    7    9   11   13]
                ↑    ↑    ↑    ↑    ↑    ↑    ↑
               T₁   T₂   T₃   T₄   T₅   T₆   T₇
```
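
Each parallel step applies the same update rule to every active thread, reading the values produced by the previous step:

\[
\text{shared}[i] \;\leftarrow\; \text{shared}[i] + \text{shared}[i - \text{offset}] \qquad \text{for } i \ge \text{offset}.
\]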
Offset = 2: Second Parallel Step
Active threads: \(T_2 \ldots T_7\) (where `local_i ≥ 2`)

```txt
Before:   [0    1    3    5    7    9   11   13]
Add:                +0   +1   +3   +5   +7   +9
                     |    |    |    |    |    |
Result:   [0    1    3    6   10   14   18   22]
                     ↑    ↑    ↑    ↑    ↑    ↑
                    T₂   T₃   T₄   T₅   T₆   T₇
```
Offset = 4: Third Parallel Step
Active threads: \(T_4 \ldots T_7\) (where `local_i ≥ 4`)

```txt
Before:   [0    1    3    6   10   14   18   22]
Add:                          +0   +1   +3   +6
                               |    |    |    |
Result:   [0    1    3    6   10   15   21   28]
                               ↑    ↑    ↑    ↑
                              T₄   T₅   T₆   T₇
```
Final Write to Output

```txt
Threads:    T₀   T₁   T₂   T₃   T₄   T₅   T₆   T₇
global_i:    0    1    2    3    4    5    6    7
out[]:      [0    1    3    6   10   15   21   28]
             ↑    ↑    ↑    ↑    ↑    ↑    ↑    ↑
            T₀   T₁   T₂   T₃   T₄   T₅   T₆   T₇
```
Thread-by-Thread Execution
In each step, every active thread adds the value its partner held at the start of that step; the `barrier()` calls keep the steps separated.

\(T_0\) (`local_i=0`):
- Loads `shared[0] = 0`
- Never adds (`local_i < offset` always)
- Writes `out[0] = 0`
\(T_1\) (`local_i=1`):
- Loads `shared[1] = 1`
- `offset=1`: adds `shared[0]` → 1
- `offset=2,4`: no action (`local_i < offset`)
- Writes `out[1] = 1`
\(T_2\) (`local_i=2`):
- Loads `shared[2] = 2`
- `offset=1`: adds `shared[1]` → 3
- `offset=2`: adds `shared[0]` → 3
- `offset=4`: no action
- Writes `out[2] = 3`
\(T_3\) (`local_i=3`):
- Loads `shared[3] = 3`
- `offset=1`: adds `shared[2]` → 5
- `offset=2`: adds `shared[1]` → 6
- `offset=4`: no action
- Writes `out[3] = 6`
\(T_4\) (`local_i=4`):
- Loads `shared[4] = 4`
- `offset=1`: adds `shared[3]` → 7
- `offset=2`: adds `shared[2]` → 10
- `offset=4`: adds `shared[0]` → 10
- Writes `out[4] = 10`
\(T_5\) (`local_i=5`):
- Loads `shared[5] = 5`
- `offset=1`: adds `shared[4]` → 9
- `offset=2`: adds `shared[3]` → 14
- `offset=4`: adds `shared[1]` → 15
- Writes `out[5] = 15`
\(T_6\) (`local_i=6`):
- Loads `shared[6] = 6`
- `offset=1`: adds `shared[5]` → 11
- `offset=2`: adds `shared[4]` → 18
- `offset=4`: adds `shared[2]` → 21
- Writes `out[6] = 21`
\(T_7\) (`local_i=7`):
- Loads `shared[7] = 7`
- `offset=1`: adds `shared[6]` → 13
- `offset=2`: adds `shared[5]` → 22
- `offset=4`: adds `shared[3]` → 28
- Writes `out[7] = 28`
The solution ensures correct synchronization between phases using `barrier()` and handles array bounds checking with `if global_i < size`. The final result is the inclusive prefix sum, where each element \(i\) contains \(\sum_{j=0}^{i} a[j]\).
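For example, for \(i = 4\):

\[
\text{out}[4] = \sum_{j=0}^{4} a[j] = 0 + 1 + 2 + 3 + 4 = 10.
\]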