Key concepts
In this puzzle, you’ll learn about:
- Using shared memory for sliding window operations
- Handling boundary conditions in pooling
- Coordinating thread access to neighboring elements
The key insight is understanding how to efficiently access a window of elements using shared memory, with special handling for the first two positions in the sequence.
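Before touching the GPU, it helps to pin down what the operation computes. The sketch below is a hypothetical CPU reference (not part of the puzzle files) that prints the windowed sums for the test input `a[i] = i`:

```mojo
# Hypothetical CPU reference for the sliding window sum (not part of p09.mojo).
fn main():
    alias size = 8
    for i in range(size):
        var acc = 0
        # The window covers indices max(0, i - 2) through i, so the first
        # two outputs use truncated windows.
        for j in range(max(0, i - 2), i + 1):
            acc += j  # test input is a[j] = j
        print(acc)  # prints 0, 1, 3, 6, 9, 12, 15, 18
```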
Configuration
- Array size: `SIZE = 8` elements
- Threads per block: `TPB = 8`
- Window size: 3 elements
- Shared memory: `TPB` elements
Notes:
- Window access: Each output sums the current element and up to two elements before it
- Edge handling: The first two positions need special treatment (spelled out below)
- Memory pattern: One shared memory load per thread
- Thread sync: Coordination before window operations
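Concretely, the mapping each thread must implement is `out[0] = a[0]`, `out[1] = a[0] + a[1]`, and `out[i] = a[i-2] + a[i-1] + a[i]` for `i >= 2`.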
Code to complete
```mojo
alias TPB = 8
alias SIZE = 8
alias BLOCKS_PER_GRID = (1, 1)
alias THREADS_PER_BLOCK = (TPB, 1)
alias dtype = DType.float32


fn pooling(
    out: UnsafePointer[Scalar[dtype]],
    a: UnsafePointer[Scalar[dtype]],
    size: Int,
):
    shared = stack_allocation[
        TPB * sizeof[dtype](),
        Scalar[dtype],
        address_space = AddressSpace.SHARED,
    ]()
    global_i = block_dim.x * block_idx.x + thread_idx.x
    local_i = thread_idx.x
    # FILL ME IN (roughly 10 lines)
```
View full file: problems/p09/p09.mojo
Tips
- Load data into shared memory and call `barrier()`
- Special cases: `out[0] = shared[0]`, `out[1] = shared[0] + shared[1]`
- General case: `if 1 < global_i < size`
- Sum three elements: `shared[local_i - 2] + shared[local_i - 1] + shared[local_i]`
Running the code
To test your solution, run the following command in your terminal:
```bash
magic run p09
```
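If you are curious what that command drives, here is a rough sketch of a host-side harness, assuming the `DeviceContext` API these puzzles use elsewhere; the actual harness lives in `problems/p09/p09.mojo` and may differ in details:

```mojo
# Sketch only: assumes the DeviceContext host API used across these puzzles.
from gpu.host import DeviceContext


def main():
    with DeviceContext() as ctx:
        out = ctx.enqueue_create_buffer[dtype](SIZE).enqueue_fill(0)
        a = ctx.enqueue_create_buffer[dtype](SIZE).enqueue_fill(0)
        # Initialize the input with a[i] = i on the host.
        with a.map_to_host() as a_host:
            for i in range(SIZE):
                a_host[i] = i
        # Launch one block of TPB threads running the pooling kernel.
        ctx.enqueue_function[pooling](
            out.unsafe_ptr(),
            a.unsafe_ptr(),
            SIZE,
            grid_dim=BLOCKS_PER_GRID,
            block_dim=THREADS_PER_BLOCK,
        )
        ctx.synchronize()
```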
Your output will look like this if the puzzle isn’t solved yet:
```txt
out: HostBuffer([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
expected: HostBuffer([0.0, 1.0, 3.0, 6.0, 9.0, 12.0, 15.0, 18.0])
```
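The expected values correspond to the test input `a[i] = i`, i.e. `a = [0, 1, 2, 3, 4, 5, 6, 7]`: for example, `out[1] = a[0] + a[1] = 1` and `out[4] = a[2] + a[3] + a[4] = 2 + 3 + 4 = 9`.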
Solution
```mojo
fn pooling(
    out: UnsafePointer[Scalar[dtype]],
    a: UnsafePointer[Scalar[dtype]],
    size: Int,
):
    shared = stack_allocation[
        TPB * sizeof[dtype](),
        Scalar[dtype],
        address_space = AddressSpace.SHARED,
    ]()
    global_i = block_dim.x * block_idx.x + thread_idx.x
    local_i = thread_idx.x

    # Each thread stages one input element in shared memory.
    if global_i < size:
        shared[local_i] = a[global_i]

    # Make sure every element is loaded before any thread reads a neighbor.
    barrier()

    # Boundary cases: the first two outputs have truncated windows.
    if global_i == 0:
        out[0] = shared[0]
    elif global_i == 1:
        out[1] = shared[0] + shared[1]

    # General case: sum the current element and its two predecessors.
    if 1 < global_i < size:
        out[global_i] = (
            shared[local_i - 2] + shared[local_i - 1] + shared[local_i]
        )
```
The solution implements a sliding window sum using shared memory with these key steps:
1. Shared Memory Setup:
   - Allocates `TPB` elements in shared memory
   - Each thread loads one element from global memory
   - Uses `barrier()` to ensure all data is loaded
2. Boundary Cases:
   - Position 0: `out[0] = shared[0]` (only first element)
   - Position 1: `out[1] = shared[0] + shared[1]` (sum of first two elements)
3. Main Window Operation:
   - For positions 2 and beyond: `out[i] = shared[i-2] + shared[i-1] + shared[i]`
   - Uses local indices for shared memory access
   - Maintains coalesced memory access pattern
4. Memory Access Pattern:
   - One global read per thread into shared memory
   - One global write per thread from shared memory
   - Uses shared memory for efficient neighbor access
   - Avoids redundant global memory loads
This approach is efficient because it:
- Minimizes global memory accesses (contrast the naive variant sketched below)
- Uses shared memory for fast neighbor lookups
- Handles boundary conditions without branching in the main case
- Maintains good memory coalescing
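For contrast, here is a hypothetical variant without shared memory (`pooling_naive` is an illustrative name, not part of the puzzle): each thread pulls its window directly from global memory, so neighboring elements are fetched up to three times by adjacent threads instead of being staged once per block.

```mojo
# Hypothetical naive variant (not the puzzle's solution): each thread reads
# up to three values straight from global memory, repeating loads that
# adjacent threads already performed.
fn pooling_naive(
    out: UnsafePointer[Scalar[dtype]],
    a: UnsafePointer[Scalar[dtype]],
    size: Int,
):
    global_i = block_dim.x * block_idx.x + thread_idx.x
    if global_i == 0:
        out[0] = a[0]
    elif global_i == 1:
        out[1] = a[0] + a[1]
    elif global_i < size:
        out[global_i] = a[global_i - 2] + a[global_i - 1] + a[global_i]
```

With a window of 3, the shared-memory version replaces those up-to-three global reads per thread with a single staged load followed by fast on-chip accesses.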