Puzzle 15: Axis Sum
Overview
Implement a kernel that computes the sum of each row of the 2D matrix a and stores the result in output, using LayoutTensor.
Key concepts
This puzzle covers:
- Parallel reduction along a matrix dimension using LayoutTensor
- Data partitioning with block coordinates
- Efficient shared-memory reduction patterns
- Working with multi-dimensional tensor layouts
The key is understanding how to map thread blocks to matrix rows and perform an efficient parallel reduction within each block while taking advantage of LayoutTensor's dimension-aware indexing.
Configuration
- Matrix size: \(\text{BATCH} \times \text{SIZE} = 4 \times 6\)
- Threads per block: \(\text{TPB} = 8\)
- Grid dimensions: \(1 \times \text{BATCH}\)
- Shared memory: \(\text{TPB}\) elements per block
- Input layout: Layout.row_major(BATCH, SIZE)
- Output layout: Layout.row_major(BATCH, 1)
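As a quick sanity check on this configuration, the following plain-Python sketch (illustration only, not part of the puzzle code) works out the launch geometry implied by the constants above:

```python
# Launch geometry from the configuration above (plain Python, for illustration).
BATCH, SIZE, TPB = 4, 6, 8

blocks = (1, BATCH)           # grid: one block per matrix row
threads_per_block = (TPB, 1)  # TPB threads cooperate on one row
total_threads = blocks[0] * blocks[1] * TPB

padding_lanes = TPB - SIZE    # threads in each block with no input element

print(total_threads)  # 32 threads launched for 24 matrix elements
print(padding_lanes)  # 2 padding lanes per block must be zero-initialized
```

Because TPB exceeds SIZE, two lanes per block are padding, which is why the kernel must guard its loads and zero-fill the unused cache slots.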
Matrix visualization:
Row 0: [ 0,  1,  2,  3,  4,  5] → Block(0,0)
Row 1: [ 6,  7,  8,  9, 10, 11] → Block(0,1)
Row 2: [12, 13, 14, 15, 16, 17] → Block(0,2)
Row 3: [18, 19, 20, 21, 22, 23] → Block(0,3)
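Each block must reduce its row to a single value. A short Python check (illustrative, using the matrix from the visualization) shows what each block is expected to produce:

```python
# Build the 4x6 matrix from the visualization and compute the per-row sums
# that each block's reduction must produce.
BATCH, SIZE = 4, 6
a = [[row * SIZE + col for col in range(SIZE)] for row in range(BATCH)]

row_sums = [sum(row) for row in a]
print(row_sums)  # [15, 51, 87, 123]
```

These are exactly the values the finished kernel writes to output, one per batch.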
Code to complete
from gpu import thread_idx, block_idx, block_dim, barrier
from gpu.memory import AddressSpace
from layout import Layout, LayoutTensor
comptime TPB = 8
comptime BATCH = 4
comptime SIZE = 6
comptime BLOCKS_PER_GRID = (1, BATCH)
comptime THREADS_PER_BLOCK = (TPB, 1)
comptime dtype = DType.float32
comptime in_layout = Layout.row_major(BATCH, SIZE)
comptime out_layout = Layout.row_major(BATCH, 1)
fn axis_sum[
    in_layout: Layout, out_layout: Layout
](
    output: LayoutTensor[dtype, out_layout, MutAnyOrigin],
    a: LayoutTensor[dtype, in_layout, ImmutAnyOrigin],
    size: UInt,
):
    global_i = block_dim.x * block_idx.x + thread_idx.x
    local_i = thread_idx.x
    batch = block_idx.y
    # FILL ME IN (roughly 15 lines)
View the full file: problems/p15/p15.mojo
Tips
- Select the row with batch = block_idx.y
- Load safely: cache[local_i] = a[batch, local_i]
- Perform a parallel reduction, halving the stride each step
- Have thread 0 write the final sum to output[batch]
Running the code
To test your solution, run one of the following commands in your terminal:
pixi run p15
pixi run -e amd p15
pixi run -e apple p15
uv run poe p15
If you haven't solved the puzzle yet, the output will look like this:
out: DeviceBuffer([0.0, 0.0, 0.0, 0.0])
expected: HostBuffer([15.0, 51.0, 87.0, 123.0])
Solution
fn axis_sum[
    in_layout: Layout, out_layout: Layout
](
    output: LayoutTensor[dtype, out_layout, MutAnyOrigin],
    a: LayoutTensor[dtype, in_layout, ImmutAnyOrigin],
    size: UInt,
):
    global_i = block_dim.x * block_idx.x + thread_idx.x
    local_i = thread_idx.x
    batch = block_idx.y
    cache = LayoutTensor[
        dtype,
        Layout.row_major(TPB),
        MutAnyOrigin,
        address_space = AddressSpace.SHARED,
    ].stack_allocation()
    # Visualize:
    # Block(0,0): [T0,T1,T2,T3,T4,T5,T6,T7] -> Row 0: [0,1,2,3,4,5]
    # Block(0,1): [T0,T1,T2,T3,T4,T5,T6,T7] -> Row 1: [6,7,8,9,10,11]
    # Block(0,2): [T0,T1,T2,T3,T4,T5,T6,T7] -> Row 2: [12,13,14,15,16,17]
    # Block(0,3): [T0,T1,T2,T3,T4,T5,T6,T7] -> Row 3: [18,19,20,21,22,23]
    # Each row is handled by one block because grid_dim = (1, BATCH)
    if local_i < size:
        cache[local_i] = a[batch, local_i]
    else:
        # Zero-initialize the padding elements for the later reduction
        cache[local_i] = 0
    barrier()
    # Tree-reduction sum within each block
    stride = UInt(TPB // 2)
    while stride > 0:
        # Read phase: all threads read the values they need first to avoid race conditions
        var temp_val: output.element_type = 0
        if local_i < stride:
            temp_val = cache[local_i + stride]
        barrier()
        # Write phase: all threads safely write their computed values
        if local_i < stride:
            cache[local_i] += temp_val
        barrier()
        stride //= 2
    # Thread 0 holds the full row sum; write it once per batch
    if local_i == 0:
        output[batch, 0] = cache[0]
This is a parallel reduction that computes row-wise sums of a 2D matrix using LayoutTensor. Let's walk through it step by step:
Matrix layout and block mapping
Input Matrix (4×6) with LayoutTensor:            Block Assignment:
[[ a[0,0] a[0,1] a[0,2] a[0,3] a[0,4] a[0,5] ]   → Block(0,0)
 [ a[1,0] a[1,1] a[1,2] a[1,3] a[1,4] a[1,5] ]   → Block(0,1)
 [ a[2,0] a[2,1] a[2,2] a[2,3] a[2,4] a[2,5] ]   → Block(0,2)
 [ a[3,0] a[3,1] a[3,2] a[3,3] a[3,4] a[3,5] ]]  → Block(0,3)
Parallel reduction process
- Initial data loading:
Block(0,0): cache = [a[0,0] a[0,1] a[0,2] a[0,3] a[0,4] a[0,5] * *]  // * = padding
Block(0,1): cache = [a[1,0] a[1,1] a[1,2] a[1,3] a[1,4] a[1,5] * *]
Block(0,2): cache = [a[2,0] a[2,1] a[2,2] a[2,3] a[2,4] a[2,5] * *]
Block(0,3): cache = [a[3,0] a[3,1] a[3,2] a[3,3] a[3,4] a[3,5] * *]
- Reduction steps (for Block 0,0, padding zero-initialized):
Initial:  [ 0  1  2  3  4  5  0  0]
Stride 4: [ 4  6  2  3  4  5  0  0]
Stride 2: [ 6  9  2  3  4  5  0  0]
Stride 1: [15  9  2  3  4  5  0  0]
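The per-step cache states can be reproduced with a short Python simulation of the stride-halving loop (an illustration of the algorithm, not the kernel itself; the real threads execute each stride step in parallel):

```python
# Sequential simulation of the in-block tree reduction for Block(0,0).
TPB, SIZE = 8, 6
row = [0, 1, 2, 3, 4, 5]             # Row 0 of the input matrix

cache = row + [0] * (TPB - SIZE)     # padding lanes zero-initialized
stride = TPB // 2
while stride > 0:
    for i in range(stride):          # one "thread" per position i
        cache[i] += cache[i + stride]
    stride //= 2

print(cache[0])  # 15, the sum of row 0
```

After log2(TPB) = 3 halvings, the row sum has accumulated into cache[0], which thread 0 then writes out.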
Key implementation features
- Layout configuration:
  - Input: row-major layout (BATCH × SIZE)
  - Output: row-major layout (BATCH × 1)
  - Each block processes one entire row
- Memory access pattern:
  - Input uses LayoutTensor 2D indexing: a[batch, local_i]
  - Shared memory is used for the efficient reduction
  - Output uses LayoutTensor 2D indexing: output[batch, 0]
- Parallel reduction logic:
stride = TPB // 2
while stride > 0:
    if local_i < stride:
        cache[local_i] += cache[local_i + stride]
    barrier()
    stride //= 2
Note: in this form, threads read and write shared memory within the same iteration, so a potential race condition can arise. A safer approach separates the read and write phases:
stride = TPB // 2
while stride > 0:
    var temp_val: output.element_type = 0
    if local_i < stride:
        temp_val = cache[local_i + stride]  # read phase
    barrier()
    if local_i < stride:
        cache[local_i] += temp_val  # write phase
    barrier()
    stride //= 2
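To see why the phase separation is safe, here is a hypothetical Python model (not kernel code) that runs the "threads" of each phase in a random order; because every read completes before any write, the scheduling order cannot change the result:

```python
import random

# Model of the two-phase reduction: within each stride step, all reads happen
# before any write, so the per-phase execution order of "threads" cannot matter.
TPB = 8

def two_phase_reduce(values, seed):
    rng = random.Random(seed)
    cache = values + [0] * (TPB - len(values))
    stride = TPB // 2
    while stride > 0:
        order = list(range(stride))
        rng.shuffle(order)                            # arbitrary thread scheduling
        temp = {i: cache[i + stride] for i in order}  # read phase
        rng.shuffle(order)
        for i in order:                               # write phase
            cache[i] += temp[i]
        stride //= 2
    return cache[0]

results = {two_phase_reduce([0, 1, 2, 3, 4, 5], seed) for seed in range(20)}
print(results)  # {15}: the same sum regardless of scheduling order
```

The barriers in the kernel play the role of the phase boundaries here: no thread starts writing until every thread has finished reading.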
- Output writing:
if local_i == 0:
    output[batch, 0] = cache[0]  # one result per batch
Performance optimization
- Memory efficiency:
  - Coalesced memory access through LayoutTensor
  - Shared memory for a fast reduction
  - A single write per row result
- Thread utilization:
  - Balanced workload across rows
  - No thread divergence in the main computation
  - Efficient parallel reduction pattern
- Synchronization:
  - Minimal barriers (used only during the reduction)
  - Independent processing per row
  - No inter-block communication needed
- Race-condition considerations: the naive in-place reduction can suffer read-write conflicts during the parallel reduction, which can be resolved by explicitly separating the read and write phases
Complexity analysis
- Time: \(O(\log n)\) per row, where n is the row length
- Space: \(O(\text{TPB})\) shared memory per block
- Total parallel time: \(O(\log n)\) given sufficient threads
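As a concrete check of the \(O(\log n)\) claim, the stride loop for \(\text{TPB} = 8\) performs exactly \(\log_2 8 = 3\) iterations. A tiny Python count (illustrative only):

```python
# Count reduction steps for a block of TPB threads: the stride halves each
# iteration (4, 2, 1), so the loop runs log2(TPB) times.
TPB = 8
steps = 0
stride = TPB // 2
while stride > 0:
    steps += 1
    stride //= 2

print(steps)  # 3, i.e. log2(8)
```

With enough threads, each of those 3 steps takes constant time, giving the logarithmic total above.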