Key concepts
In this puzzle, you’ll learn about:
- Using shared memory within thread blocks
- Synchronizing threads with barriers
- Managing block-local data storage
The key insight is that shared memory provides fast, block-local storage that every thread in a block can access, so the threads in a block must coordinate their reads and writes.
Configuration
- Array size: SIZE = 8 elements
- Threads per block: TPB = 4
- Number of blocks: 2
- Shared memory: TPB elements per block
Notes:
- Shared memory: Fast storage shared by threads in a block
- Thread sync: Coordination using barrier()
- Memory scope: Shared memory only visible within block
- Access pattern: Local index (local_i) for shared memory, global index (global_i) for the input and output arrays
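The relationship between local and global indices under this configuration can be traced on the CPU. The following is a minimal sketch, not part of the puzzle file: it replaces the GPU built-ins block_idx.x and thread_idx.x with plain loop variables (an assumption made purely for illustration) and prints the index each thread would compute.

```mojo
alias TPB = 4     # threads per block, as in the configuration above
alias BLOCKS = 2  # number of blocks, as in the configuration above

fn main():
    # The loop variables stand in for block_idx.x and thread_idx.x (illustration only).
    for block in range(BLOCKS):
        for local_i in range(TPB):
            # Same formula as the kernel: block_dim.x * block_idx.x + thread_idx.x
            var global_i = TPB * block + local_i
            print("block", block, "local_i", local_i, "global_i", global_i)
```

Block 0 covers global indices 0 to 3 and block 1 covers 4 to 7, while local_i runs from 0 to 3 in both blocks. This is why shared memory is indexed with local_i and the global arrays with global_i.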
Warning: Each block has a fixed amount of shared memory that its threads can read and write. The size must be a compile-time constant (such as the alias TPB), not a runtime variable. After writing to shared memory, call barrier() so that no thread reads a slot before the thread responsible for it has finished writing.
Code to complete
alias TPB = 4
alias SIZE = 8
alias BLOCKS_PER_GRID = (2, 1)
alias THREADS_PER_BLOCK = (TPB, 1)
alias dtype = DType.float32
fn add_10_shared(
    out: UnsafePointer[Scalar[dtype]],
    a: UnsafePointer[Scalar[dtype]],
    size: Int,
):
    shared = stack_allocation[
        TPB * sizeof[dtype](),
        Scalar[dtype],
        address_space = AddressSpace.SHARED,
    ]()
    global_i = block_dim.x * block_idx.x + thread_idx.x
    local_i = thread_idx.x
    # local data into shared memory
    if global_i < size:
        shared[local_i] = a[global_i]
    # wait for all threads to complete
    # works within a thread block
    barrier()
    # FILL ME IN (roughly 2 lines)
View full file: problems/p08/p08.mojo
Tips
- Wait for shared memory load with barrier()
- Use local_i to access shared memory: shared[local_i]
- Use global_i for output: out[global_i]
- Add guard: if global_i < size
Running the code
To test your solution, run the following command in your terminal:
magic run p08
Your output will look like this if the puzzle isn’t solved yet:
out: HostBuffer([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
expected: HostBuffer([11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0])
Solution
fn add_10_shared(
    out: UnsafePointer[Scalar[dtype]],
    a: UnsafePointer[Scalar[dtype]],
    size: Int,
):
    shared = stack_allocation[
        TPB * sizeof[dtype](),
        Scalar[dtype],
        address_space = AddressSpace.SHARED,
    ]()
    global_i = block_dim.x * block_idx.x + thread_idx.x
    local_i = thread_idx.x
    # local data into shared memory
    if global_i < size:
        shared[local_i] = a[global_i]
    # wait for all threads to complete
    # works within a thread block
    barrier()
    # process using shared memory
    if global_i < size:
        out[global_i] = shared[local_i] + 10
This solution:
- Waits for shared memory load with barrier()
- Guards against out-of-bounds with if global_i < size
- Reads from shared memory using shared[local_i]
- Writes result to global memory at out[global_i]
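The two phases of the solution (load, then compute) can also be traced with a CPU-only sketch. This is an illustration under stated assumptions rather than the puzzle's test harness: the input is assumed to be filled with 1.0 (inferred from the expected output of 11.0), each block's shared memory is modelled as an ordinary List, and finishing the first loop before starting the second stands in for barrier().

```mojo
alias TPB = 4
alias SIZE = 8
alias BLOCKS = 2

fn main():
    # Assumed input: all ones, inferred from the expected output of 11.0.
    var a = List[Float32]()
    var out = List[Float32]()
    for _ in range(SIZE):
        a.append(1.0)
        out.append(0.0)

    for block in range(BLOCKS):
        # Stand-in for the block's shared memory: TPB elements, private to this block.
        var shared = List[Float32]()
        for _ in range(TPB):
            shared.append(0.0)

        # Phase 1: each "thread" copies its element from global to shared memory.
        for local_i in range(TPB):
            var global_i = TPB * block + local_i
            if global_i < SIZE:
                shared[local_i] = a[global_i]

        # Completing the loop above before the next one starts plays the role of barrier().

        # Phase 2: each "thread" reads its shared slot and writes the result to global memory.
        for local_i in range(TPB):
            var global_i = TPB * block + local_i
            if global_i < SIZE:
                out[global_i] = shared[local_i] + 10

    for i in range(SIZE):
        print(out[i])
```

On a real GPU the threads of a block run concurrently rather than in loop order, which is why the kernel needs an explicit barrier() between the two phases.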