TileTensor Version
Overview
Implement a kernel that adds 10 to each element of the 2D TileTensor a and
stores the result in the 2D TileTensor output.
Note: You have more threads than positions.
Key concepts
In this puzzle, you’ll learn about:
- Using TileTensor for 2D array access
- Direct 2D indexing with tensor[i, j]
- Handling bounds checking with TileTensor
The key insight is that TileTensor provides a natural 2D indexing interface
that abstracts away the underlying memory layout, but you still have to check
bounds yourself.
- 2D access: Natural \((i,j)\) indexing with TileTensor
- Memory abstraction: No manual row-major index calculation needed
- Guard condition: Still need bounds checking in both dimensions
- Thread bounds: More threads \((3 \times 3)\) than tensor elements \((2 \times 2)\)
Code to complete
comptime SIZE = 2
comptime BLOCKS_PER_GRID = 1
comptime THREADS_PER_BLOCK = (3, 3)
comptime dtype = DType.float32
comptime layout = row_major[SIZE, SIZE]()
comptime LayoutType = type_of(layout)
def add_10_2d(
    output: TileTensor[mut=True, dtype, LayoutType, MutAnyOrigin],
    a: TileTensor[mut=True, dtype, LayoutType, MutAnyOrigin],
    size: Int,
):
    var row = thread_idx.y
    var col = thread_idx.x
    # FILL ME IN (roughly 2 lines)
View full file: problems/p04/p04_tile_tensor.mojo
Tips
- Get 2D indices: row = thread_idx.y, col = thread_idx.x
- Add a guard: if row < size and col < size
- Inside the guard, add 10 to a[row, col] and write the result to output[row, col]
Running the code
To test your solution, run the command that matches your setup in your terminal:
pixi run p04_tile_tensor
pixi run -e amd p04_tile_tensor
pixi run -e apple p04_tile_tensor
uv run poe p04_tile_tensor
Your output will look like this if the puzzle isn’t solved yet:
out: HostBuffer([0.0, 0.0, 0.0, 0.0])
expected: HostBuffer([10.0, 11.0, 12.0, 13.0])
Solution
def add_10_2d(
    output: TileTensor[mut=True, dtype, LayoutType, MutAnyOrigin],
    a: TileTensor[mut=True, dtype, LayoutType, MutAnyOrigin],
    size: Int,
):
    var row = thread_idx.y
    var col = thread_idx.x
    if row < size and col < size:
        output[row, col] = a[row, col] + 10.0
This solution:
- Gets 2D thread indices with row = thread_idx.y, col = thread_idx.x
- Guards against out-of-bounds access with if row < size and col < size
- Uses TileTensor’s 2D indexing: output[row, col] = a[row, col] + 10.0
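The following is a rough CPU emulation of the solution that reproduces the expected output. It is written in Python rather than Mojo, and it runs the nine threads of the single block one after another instead of in parallel. One assumption: the host fills a with 0, 1, 2, 3 in row-major order. That input is inferred from the expected HostBuffer([10.0, 11.0, 12.0, 13.0]) and is not taken from the puzzle file.

```python
# CPU emulation of the add_10_2d solution (plain Python, not Mojo).
# Assumption: a holds 0, 1, 2, 3 in row-major order, inferred from the
# expected output HostBuffer([10.0, 11.0, 12.0, 13.0]).
SIZE = 2
THREADS_PER_BLOCK = (3, 3)


def add_10_2d(output, a, size, row, col):
    """Kernel body for one thread; output and a are nested lists."""
    if row < size and col < size:
        output[row][col] = a[row][col] + 10.0


def launch(a, size):
    """Run every thread of the single 3x3 block sequentially."""
    output = [[0.0] * size for _ in range(size)]
    for row in range(THREADS_PER_BLOCK[1]):
        for col in range(THREADS_PER_BLOCK[0]):
            add_10_2d(output, a, size, row, col)
    return output


a = [[float(r * SIZE + c) for c in range(SIZE)] for r in range(SIZE)]
out = launch(a, SIZE)
flat = [x for row in out for x in row]  # row-major flattening, like the HostBuffer printout
print(flat)
```

If you delete the guard in this emulation, it raises an IndexError at thread (row 0, col 2). Python lists catch the out-of-range access that a GPU kernel would silently perform.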