Key concepts

In this puzzle, you’ll learn about:

  • Using shared memory within thread blocks
  • Synchronizing threads with barriers
  • Managing block-local data storage

The key insight is that shared memory is fast, block-local storage that every thread in a block can read and write. Because threads share it, they must coordinate with a barrier before any thread reads data that other threads have written.

Configuration

  • Array size: SIZE = 8 elements
  • Threads per block: TPB = 4
  • Number of blocks: 2
  • Shared memory: TPB elements per block

Notes:

  • Shared memory: Fast storage shared by threads in a block
  • Thread sync: Coordination using barrier()
  • Memory scope: Shared memory is visible only to threads in the same block
  • Access pattern: local_i indexes shared memory, global_i indexes global arrays
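The following Python sketch makes the local-versus-global distinction concrete. It enumerates every (block, thread) pair for this configuration and computes both indices with the same formula the kernel uses. It is a plain CPU model, not Mojo GPU code, and the function name `index_map` is hypothetical.

```python
# CPU model of the index arithmetic used in add_10_shared (not GPU code).
TPB = 4     # threads per block, as in the puzzle configuration
BLOCKS = 2  # number of blocks


def index_map(blocks: int, tpb: int) -> list[tuple[int, int, int, int]]:
    """Return (block_idx, thread_idx, global_i, local_i) for every thread."""
    rows = []
    for block in range(blocks):
        for thread in range(tpb):
            global_i = tpb * block + thread  # block_dim.x * block_idx.x + thread_idx.x
            local_i = thread                 # index into this block's shared memory
            rows.append((block, thread, global_i, local_i))
    return rows


for row in index_map(BLOCKS, TPB):
    print("block=%d thread=%d global_i=%d local_i=%d" % row)
```

Both blocks use local indices 0 to 3, but each block addresses its own copy of shared memory. As a result, thread 0 of block 0 and thread 0 of block 1 never collide, even though they use the same local_i.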

Warning: Each block has only a fixed amount of shared memory, which the threads in that block can read and write. Its size must be a compile-time constant (such as the TPB alias), not a runtime variable. After writing to shared memory, call barrier() so that no thread reads shared memory before every thread in the block has finished writing to it.
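The barrier rule can be modeled on the CPU with Python's `threading.Barrier`. In this sketch, each simulated thread writes one slot of a fixed-size, block-local buffer, waits at the barrier, and then reads the slot written by its neighbor. The neighbor read is a hypothetical variation chosen only to show why the wait matters. Without `barrier.wait()`, a thread could read its neighbor's slot before that neighbor had written it. In this puzzle, each thread reads back only its own slot, but the same load, barrier, compute structure applies.

```python
import threading

TPB = 4  # threads per block; the shared buffer has a fixed size, like the compile-time constant


def run_block(block_input: list[float]) -> list[float]:
    """Simulate one block: each thread writes its slot, waits, then reads its neighbor's slot."""
    shared = [0.0] * TPB              # stands in for the block's shared memory
    result = [0.0] * TPB
    barrier = threading.Barrier(TPB)  # one barrier per block, like barrier() on the GPU

    def thread_body(local_i: int) -> None:
        shared[local_i] = block_input[local_i]  # phase 1: write own slot
        barrier.wait()                          # nobody continues until all TPB threads have written
        # phase 2 (hypothetical neighbor read): only safe because of the barrier above
        result[local_i] = shared[(local_i + 1) % TPB]

    threads = [threading.Thread(target=thread_body, args=(i,)) for i in range(TPB)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return result


print(run_block([1.0, 2.0, 3.0, 4.0]))  # [2.0, 3.0, 4.0, 1.0]
```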

Code to complete

alias TPB = 4
alias SIZE = 8
alias BLOCKS_PER_GRID = (2, 1)
alias THREADS_PER_BLOCK = (TPB, 1)
alias dtype = DType.float32


fn add_10_shared(
    out: UnsafePointer[Scalar[dtype]],
    a: UnsafePointer[Scalar[dtype]],
    size: Int,
):
    # allocate TPB elements of block-local shared memory
    # (the first parameter is an element count, not a byte count)
    shared = stack_allocation[
        TPB,
        Scalar[dtype],
        address_space = AddressSpace.SHARED,
    ]()
    global_i = block_dim.x * block_idx.x + thread_idx.x
    local_i = thread_idx.x
    # load data into shared memory
    if global_i < size:
        shared[local_i] = a[global_i]

    # wait until every thread in this block has written its element;
    # barrier() synchronizes only threads within the same block
    barrier()

    # FILL ME IN (roughly 2 lines)


View full file: problems/p08/p08.mojo

Tips
  1. Wait for shared memory load with barrier()
  2. Use local_i to access shared memory: shared[local_i]
  3. Use global_i for output: out[global_i]
  4. Add guard: if global_i < size
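The guard matters whenever the array length is not a multiple of the block size. The Python sketch below is a CPU model that launches the same 2 × 4 grid and records which threads pass `global_i < size`. The `size = 6` case is hypothetical: this puzzle uses SIZE = 8, so every thread is in range. The function name `active_threads` is also illustrative.

```python
TPB = 4     # threads per block
BLOCKS = 2  # blocks in the grid


def active_threads(size: int) -> list[int]:
    """Global indices of threads that pass the `global_i < size` guard."""
    active = []
    for block in range(BLOCKS):
        for thread in range(TPB):
            global_i = TPB * block + thread
            if global_i < size:  # same guard as in the kernel
                active.append(global_i)
    return active


print(active_threads(8))  # [0, 1, 2, 3, 4, 5, 6, 7]
print(active_threads(6))  # [0, 1, 2, 3, 4, 5]
```

With size = 6, threads 6 and 7 skip both the load and the write, so they never touch out-of-bounds elements of a or out. They must still reach barrier(), which is why the barrier call sits outside the if.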

Running the code

To test your solution, run the following command in your terminal:

magic run p08

Your output will look like this if the puzzle isn’t solved yet:

out: HostBuffer([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
expected: HostBuffer([11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0])

Solution

fn add_10_shared(
    out: UnsafePointer[Scalar[dtype]],
    a: UnsafePointer[Scalar[dtype]],
    size: Int,
):
    # allocate TPB elements of block-local shared memory
    # (the first parameter is an element count, not a byte count)
    shared = stack_allocation[
        TPB,
        Scalar[dtype],
        address_space = AddressSpace.SHARED,
    ]()
    global_i = block_dim.x * block_idx.x + thread_idx.x
    local_i = thread_idx.x
    # load data into shared memory
    if global_i < size:
        shared[local_i] = a[global_i]

    # wait until every thread in this block has written its element;
    # barrier() synchronizes only threads within the same block
    barrier()

    # read from shared memory, add 10, and write the result to global memory
    if global_i < size:
        out[global_i] = shared[local_i] + 10


This solution:

  • Waits for shared memory load with barrier()
  • Guards against out-of-bounds with if global_i < size
  • Reads from shared memory using shared[local_i]
  • Writes result to global memory at out[global_i]
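The whole solution can be checked end to end with a CPU model. The sketch below runs one Python thread per GPU thread, gives each block its own shared buffer and `threading.Barrier`, and assumes an input of all ones, which is consistent with the expected output shown earlier. The function name `add_10_shared_sim` is hypothetical.

```python
import threading

TPB = 4
SIZE = 8
BLOCKS = 2


def add_10_shared_sim(a: list[float], size: int) -> list[float]:
    """CPU model of add_10_shared: one Python thread per GPU thread, one barrier per block."""
    out = [0.0] * len(a)

    def run_block(block: int) -> None:
        shared = [0.0] * TPB              # fresh shared buffer, visible only to this block
        barrier = threading.Barrier(TPB)  # synchronizes the TPB threads of this block only

        def thread_body(thread: int) -> None:
            global_i = TPB * block + thread  # block_dim.x * block_idx.x + thread_idx.x
            local_i = thread
            if global_i < size:
                shared[local_i] = a[global_i]  # load into shared memory
            barrier.wait()                     # every thread reaches the barrier, guarded or not
            if global_i < size:
                out[global_i] = shared[local_i] + 10  # compute from shared, write to global

        threads = [threading.Thread(target=thread_body, args=(t,)) for t in range(TPB)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

    for block in range(BLOCKS):
        run_block(block)
    return out


print(add_10_shared_sim([1.0] * SIZE, SIZE))  # [11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0]
```

Blocks run one after another in this model. On a GPU, blocks may run in any order or concurrently, and barrier() never synchronizes threads across blocks.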