Overview
Implement a kernel that adds 10 to each element of the 2D square matrix a and stores the result in the 2D square matrix output.
Note: You have more threads than positions.
Key concepts
In this puzzle, you’ll learn about:
- Working with 2D thread indices (thread_idx.x, thread_idx.y)
- Converting 2D coordinates to 1D memory indices
- Handling boundary checks in two dimensions
The key insight is understanding how to map from 2D thread coordinates \((i,j)\) to elements in a row-major matrix of size \(n \times n\), while ensuring thread indices are within bounds.
- 2D indexing: Each thread has a unique \((i,j)\) position
- Memory layout: Row-major ordering maps 2D to 1D memory
- Guard condition: Need bounds checking in both dimensions
- Thread bounds: More threads \((3 \times 3)\) than matrix elements \((2 \times 2)\)
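To make the mapping concrete, here is a minimal CPU-only sketch that walks the same \(3 \times 3\) thread coordinates and prints where each one lands in a \(2 \times 2\) row-major matrix. The helper names flat_index and in_bounds are hypothetical and do not appear in the puzzle files, and the nested loop only stands in for the GPU launching one thread per (row, col) pair.

```mojo
fn flat_index(row: Int, col: Int, size: Int) -> Int:
    # Row-major layout: skip `row` complete rows of `size` elements,
    # then move `col` elements into the current row.
    return row * size + col


fn in_bounds(row: Int, col: Int, size: Int) -> Bool:
    # Both coordinates must be inside the matrix.
    return row < size and col < size


def main():
    var size = 2     # the matrix is size x size
    var threads = 3  # threads per block dimension, as in the puzzle
    # Each (row, col) pair plays the role of one GPU thread.
    for row in range(threads):
        for col in range(threads):
            if in_bounds(row, col, size):
                print("thread (", row, ",", col, ") -> index", flat_index(row, col, size))
            else:
                print("thread (", row, ",", col, ") -> out of bounds, does nothing")
```

Note that thread (0, 2) would compute flat index 0 * 2 + 2 = 2, which is a valid offset into the 4-element buffer but belongs to element (1, 0). That is why the guard checks each coordinate separately instead of checking only the flat index.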
Code to complete
alias SIZE = 2
alias BLOCKS_PER_GRID = 1
alias THREADS_PER_BLOCK = (3, 3)
alias dtype = DType.float32
fn add_10_2d(
    output: UnsafePointer[Scalar[dtype], MutAnyOrigin],
    a: UnsafePointer[Scalar[dtype], MutAnyOrigin],
    size: Int,
):
    row = thread_idx.y
    col = thread_idx.x
    # FILL ME IN (roughly 2 lines)
View full file: problems/p04/p04.mojo
Tips
- Get 2D indices: row = thread_idx.y, col = thread_idx.x
- Add guard: if row < size and col < size
- Inside the guard, add 10 using the row-major index!
Running the code
To test your solution, run the following command in your terminal:
pixi run p04
pixi run -e amd p04
pixi run -e apple p04
uv run poe p04
Your output will look like this if the puzzle isn’t solved yet:
out: HostBuffer([0.0, 0.0, 0.0, 0.0])
expected: HostBuffer([10.0, 11.0, 12.0, 13.0])
Solution
fn add_10_2d(
    output: UnsafePointer[Scalar[dtype], MutAnyOrigin],
    a: UnsafePointer[Scalar[dtype], MutAnyOrigin],
    size: Int,
):
    row = thread_idx.y
    col = thread_idx.x
    if row < size and col < size:
        output[row * size + col] = a[row * size + col] + 10.0
This solution:
- Gets the 2D indices: row = thread_idx.y, col = thread_idx.x
- Adds a guard: if row < size and col < size
- Inside the guard, writes the result at the row-major index: output[row * size + col] = a[row * size + col] + 10.0
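As a rough sanity check, the same logic can be replayed on the CPU. The sketch below is not the puzzle's test harness: simulate_add_10_2d is a hypothetical helper, List stands in for the device buffers, and the input values 0 through 3 are an assumption inferred from the expected output shown earlier. The nested loop visits all 9 thread coordinates in turn, whereas the GPU would run them in parallel.

```mojo
fn simulate_add_10_2d(size: Int, threads: Int) -> List[Float32]:
    # Assumed input: a = [0, 1, 2, 3, ...] stored in row-major order.
    var a = List[Float32]()
    var output = List[Float32]()
    for i in range(size * size):
        a.append(Float32(i))
        output.append(Float32(0))

    # One loop iteration per simulated thread (row, col).
    for row in range(threads):
        for col in range(threads):
            # Same guard and row-major index as the kernel body.
            if row < size and col < size:
                output[row * size + col] = a[row * size + col] + 10.0
    return output^


def main():
    var result = simulate_add_10_2d(2, 3)
    # Print the four stored values in memory order.
    for i in range(len(result)):
        print(result[i])
```

Only 4 of the 9 simulated threads pass the guard. The other 5 skip the write, so nothing is stored outside the 4-element buffer.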