Puzzle 15: Axis Sum

Overview

Implement a kernel that computes the sum of each row of the 2D matrix a and stores the result in output, using LayoutTensor.

Axis sum visualization

Key concepts

This puzzle covers:

  • Parallel reduction along a matrix dimension using LayoutTensor
  • Partitioning data with block coordinates
  • Efficient shared-memory reduction patterns
  • Working with multi-dimensional tensor layouts

The key is understanding how to map thread blocks to matrix rows and perform an efficient parallel reduction within each block, while taking advantage of LayoutTensor's per-dimension indexing.

Configuration

  • Matrix size: \(\text{BATCH} \times \text{SIZE} = 4 \times 6\)
  • Threads per block: \(\text{TPB} = 8\)
  • Grid size: \(1 \times \text{BATCH}\)
  • Shared memory: \(\text{TPB}\) elements per block
  • Input layout: Layout.row_major(BATCH, SIZE)
  • Output layout: Layout.row_major(BATCH, 1)
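Row-major means that element \((r, c)\) lives at linear offset \(r \cdot \text{SIZE} + c\). This small Python check of that arithmetic is an illustrative model only, not the actual Layout API:

```python
BATCH, SIZE = 4, 6

def row_major_offset(r, c, cols=SIZE):
    """Linear offset of element (r, c) in a row-major layout with `cols` columns."""
    return r * cols + c

# In this puzzle the input values equal their offsets, so the
# element at (r, c) is simply its own linear offset.
assert row_major_offset(0, 0) == 0
assert row_major_offset(1, 0) == 6   # first element of row 1
assert row_major_offset(3, 5) == 23  # last element of the matrix
```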

ํ–‰๋ ฌ ์‹œ๊ฐํ™”:

Row 0: [0, 1, 2, 3, 4, 5]       → Block(0,0)
Row 1: [6, 7, 8, 9, 10, 11]     → Block(0,1)
Row 2: [12, 13, 14, 15, 16, 17] → Block(0,2)
Row 3: [18, 19, 20, 21, 22, 23] → Block(0,3)
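Since row r holds the values 6r through 6r + 5, the row sums the kernel must produce can be checked on the host. A small Python sketch, purely illustrative (the puzzle itself is in Mojo):

```python
BATCH, SIZE = 4, 6

# The puzzle input: row r holds the values 6*r .. 6*r + 5.
matrix = [[r * SIZE + c for c in range(SIZE)] for r in range(BATCH)]

# Each block reduces one row; the kernel's output is one sum per row.
expected = [sum(row) for row in matrix]
print(expected)  # [15, 51, 87, 123]
```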

Code to complete

from gpu import thread_idx, block_idx, block_dim, barrier
from gpu.memory import AddressSpace
from layout import Layout, LayoutTensor


comptime TPB = 8
comptime BATCH = 4
comptime SIZE = 6
comptime BLOCKS_PER_GRID = (1, BATCH)
comptime THREADS_PER_BLOCK = (TPB, 1)
comptime dtype = DType.float32
comptime in_layout = Layout.row_major(BATCH, SIZE)
comptime out_layout = Layout.row_major(BATCH, 1)


fn axis_sum[
    in_layout: Layout, out_layout: Layout
](
    output: LayoutTensor[dtype, out_layout, MutAnyOrigin],
    a: LayoutTensor[dtype, in_layout, ImmutAnyOrigin],
    size: UInt,
):
    global_i = block_dim.x * block_idx.x + thread_idx.x
    local_i = thread_idx.x
    batch = block_idx.y
    # FILL ME IN (roughly 15 lines)


์ „์ฒด ํŒŒ์ผ ๋ณด๊ธฐ: problems/p15/p15.mojo

Tips
  1. Select the row with batch = block_idx.y
  2. Load elements: cache[local_i] = a[batch, local_i]
  3. Perform a parallel reduction, halving the stride each step
  4. Thread 0 writes the final sum to output[batch]
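The four tips above can be sketched as a host-side Python model of one block, with the lockstep threads simulated by plain loops. This is an illustration of the algorithm, not the Mojo kernel:

```python
TPB, SIZE = 8, 6

def block_row_sum(row):
    """Simulate one thread block reducing one matrix row."""
    # Tips 1-2: each thread loads one element into shared memory;
    # threads past the row length zero-fill the padding slots.
    cache = [row[i] if i < SIZE else 0 for i in range(TPB)]
    # Tip 3: tree reduction, halving the stride each iteration
    # (on the GPU a barrier separates the iterations).
    stride = TPB // 2
    while stride > 0:
        for local_i in range(stride):
            cache[local_i] += cache[local_i + stride]
        stride //= 2
    # Tip 4: thread 0 now holds the row sum.
    return cache[0]

rows = [[r * SIZE + c for c in range(SIZE)] for r in range(4)]
print([block_row_sum(row) for row in rows])  # [15, 51, 87, 123]
```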

Running the code

To test your solution, run one of the following commands in your terminal:

pixi run p15
pixi run -e amd p15
pixi run -e apple p15
uv run poe p15

ํผ์ฆ์„ ์•„์ง ํ’€์ง€ ์•Š์•˜๋‹ค๋ฉด ์ถœ๋ ฅ์€ ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค:

out: DeviceBuffer([0.0, 0.0, 0.0, 0.0])
expected: HostBuffer([15.0, 51.0, 87.0, 123.0])

Solution

fn axis_sum[
    in_layout: Layout, out_layout: Layout
](
    output: LayoutTensor[dtype, out_layout, MutAnyOrigin],
    a: LayoutTensor[dtype, in_layout, ImmutAnyOrigin],
    size: UInt,
):
    global_i = block_dim.x * block_idx.x + thread_idx.x
    local_i = thread_idx.x
    batch = block_idx.y
    cache = LayoutTensor[
        dtype,
        Layout.row_major(TPB),
        MutAnyOrigin,
        address_space = AddressSpace.SHARED,
    ].stack_allocation()

    # Visualize:
    # Block(0,0): [T0,T1,T2,T3,T4,T5,T6,T7] -> Row 0: [0,1,2,3,4,5]
    # Block(0,1): [T0,T1,T2,T3,T4,T5,T6,T7] -> Row 1: [6,7,8,9,10,11]
    # Block(0,2): [T0,T1,T2,T3,T4,T5,T6,T7] -> Row 2: [12,13,14,15,16,17]
    # Block(0,3): [T0,T1,T2,T3,T4,T5,T6,T7] -> Row 3: [18,19,20,21,22,23]

    # each row is handled by one block because grid_dim = (1, BATCH)

    if local_i < size:
        cache[local_i] = a[batch, local_i]
    else:
        # Zero-initialize the padding elements for the later reduction
        cache[local_i] = 0

    barrier()

    # perform the reduction sum within each block
    stride = UInt(TPB // 2)
    while stride > 0:
        # Read phase: all threads read the values they need first to avoid race conditions
        var temp_val: output.element_type = 0
        if local_i < stride:
            temp_val = cache[local_i + stride]

        barrier()

        # Write phase: all threads safely write their computed values
        if local_i < stride:
            cache[local_i] += temp_val

        barrier()
        stride //= 2

    # thread 0 holds the final sum and writes it for this batch
    if local_i == 0:
        output[batch, 0] = cache[0]


This is a parallel reduction that computes the row-wise sums of a 2D matrix using LayoutTensor. Let's walk through it step by step:

ํ–‰๋ ฌ ๋ ˆ์ด์•„์›ƒ๊ณผ ๋ธ”๋ก ๋งคํ•‘

Input Matrix (4×6) with LayoutTensor:                Block Assignment:
[[ a[0,0]  a[0,1]  a[0,2]  a[0,3]  a[0,4]  a[0,5] ] → Block(0,0)
 [ a[1,0]  a[1,1]  a[1,2]  a[1,3]  a[1,4]  a[1,5] ] → Block(0,1)
 [ a[2,0]  a[2,1]  a[2,2]  a[2,3]  a[2,4]  a[2,5] ] → Block(0,2)
 [ a[3,0]  a[3,1]  a[3,2]  a[3,3]  a[3,4]  a[3,5] ] → Block(0,3)
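The block-to-row mapping can be modeled in a few lines of Python (purely illustrative; batch and local_i mirror the kernel's variables):

```python
BATCH, SIZE, TPB = 4, 6, 8

# grid_dim = (1, BATCH): block_idx.y selects which row ("batch") a block owns.
block_to_row = {}
for block_y in range(BATCH):
    batch = block_y  # mirrors `batch = block_idx.y` in the kernel
    # Threads 0..SIZE-1 load real elements; threads SIZE..TPB-1 hold padding.
    loaded = [(batch, local_i) for local_i in range(SIZE)]
    block_to_row[(0, block_y)] = loaded

# Block (0, 2) loads exactly the six elements of row 2.
print(block_to_row[(0, 2)])  # [(2, 0), (2, 1), (2, 2), (2, 3), (2, 4), (2, 5)]
```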

๋ณ‘๋ ฌ ๋ฆฌ๋•์…˜ ๊ณผ์ •

  1. ์ดˆ๊ธฐ ๋ฐ์ดํ„ฐ ๋กœ๋”ฉ:

    Block(0,0): cache = [a[0,0] a[0,1] a[0,2] a[0,3] a[0,4] a[0,5] * *]  // * = padding
    Block(0,1): cache = [a[1,0] a[1,1] a[1,2] a[1,3] a[1,4] a[1,5] * *]
    Block(0,2): cache = [a[2,0] a[2,1] a[2,2] a[2,3] a[2,4] a[2,5] * *]
    Block(0,3): cache = [a[3,0] a[3,1] a[3,2] a[3,3] a[3,4] a[3,5] * *]
    
  2. ๋ฆฌ๋•์…˜ ๋‹จ๊ณ„ (Block 0,0 ๊ธฐ์ค€):

    Initial:  [0  1  2  3  4  5  *  *]
    Stride 4: [4  5  6  7  4  5  *  *]
    Stride 2: [10 12 6  7  4  5  *  *]
    Stride 1: [15 12 6  7  4  5  *  *]
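A small Python model can replay the stride-halving passes on the padded row [0, 1, 2, 3, 4, 5, 0, 0] to check each intermediate state (illustrative only, not the Mojo code):

```python
TPB = 8

# Padded shared-memory contents for Block(0,0): row [0..5] plus two zeros.
cache = [0, 1, 2, 3, 4, 5, 0, 0]

states = []
stride = TPB // 2
while stride > 0:
    for i in range(stride):
        cache[i] += cache[i + stride]
    states.append((stride, list(cache)))
    stride //= 2

for stride, state in states:
    print(f"Stride {stride}: {state}")
# Stride 4: [4, 6, 2, 3, 4, 5, 0, 0]
# Stride 2: [6, 9, 2, 3, 4, 5, 0, 0]
# Stride 1: [15, 9, 2, 3, 4, 5, 0, 0]
```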
    

Key implementation features

  1. Layout configuration:

    • Input: row-major layout (BATCH × SIZE)
    • Output: row-major layout (BATCH × 1)
    • Each block processes one entire row
  2. Memory access pattern:

    • 2D LayoutTensor indexing for the input: a[batch, local_i]
    • Shared memory for an efficient reduction
    • 2D LayoutTensor indexing for the output: output[batch, 0]
  3. Parallel reduction logic:

    stride = TPB // 2
    while stride > 0:
        if local_i < stride:
            cache[local_i] += cache[local_i + stride]
        barrier()
        stride //= 2
    

    Note: in this naive form, threads read from and write to shared memory within the same iteration, which can give rise to a potential race condition. A safer approach separates the read and write phases:

    stride = TPB // 2
    while stride > 0:
        var temp_val: output.element_type = 0
        if local_i < stride:
            temp_val = cache[local_i + stride]  # read phase
        barrier()
        if local_i < stride:
            cache[local_i] += temp_val  # write phase
        barrier()
        stride //= 2
    
  4. Output write:

    if local_i == 0:
        output[batch, 0] = cache[0]  # one result per batch
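The read-then-write separation can be modeled sequentially in Python, with a per-iteration snapshot standing in for the values each thread reads before the barrier (an illustration of the pattern, not the Mojo code):

```python
TPB = 8

def two_phase_reduce(values):
    """Sequential model of the read-then-write reduction pattern."""
    cache = list(values)
    stride = TPB // 2
    while stride > 0:
        # Read phase: every active thread snapshots its partner's value.
        temps = [cache[i + stride] for i in range(stride)]
        # (barrier on the GPU)
        # Write phase: apply the snapshots; no slot is read and written
        # by different threads within the same iteration.
        for i in range(stride):
            cache[i] += temps[i]
        stride //= 2
    return cache[0]

print(two_phase_reduce([0, 1, 2, 3, 4, 5, 0, 0]))  # 15
```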
    

์„ฑ๋Šฅ ์ตœ์ ํ™”

  1. ๋ฉ”๋ชจ๋ฆฌ ํšจ์œจ์„ฑ:

    • LayoutTensor๋ฅผ ํ†ตํ•œ ๋ณ‘ํ•ฉ ๋ฉ”๋ชจ๋ฆฌ ์ ‘๊ทผ
    • ๋น ๋ฅธ ๋ฆฌ๋•์…˜์„ ์œ„ํ•œ ๊ณต์œ  ๋ฉ”๋ชจ๋ฆฌ ํ™œ์šฉ
    • ํ–‰ ๊ฒฐ๊ณผ๋‹น ํ•œ ๋ฒˆ์˜ ์“ฐ๊ธฐ
  2. ์Šค๋ ˆ๋“œ ํ™œ์šฉ:

    • ํ–‰ ๊ฐ„ ์™„๋ฒฝํ•œ ๋ถ€ํ•˜ ๊ท ํ˜•
    • ์ฃผ์š” ์—ฐ์‚ฐ์—์„œ ์Šค๋ ˆ๋“œ ๋ถ„๊ธฐ ์—†์Œ
    • ํšจ์œจ์ ์ธ ๋ณ‘๋ ฌ ๋ฆฌ๋•์…˜ ํŒจํ„ด
  3. ๋™๊ธฐํ™”:

    • ์ตœ์†Œํ•œ์˜ ๋ฐฐ๋ฆฌ์–ด (๋ฆฌ๋•์…˜ ์ค‘์—๋งŒ ์‚ฌ์šฉ)
    • ํ–‰ ๊ฐ„ ๋…๋ฆฝ์ ์ธ ์ฒ˜๋ฆฌ
    • ๋ธ”๋ก ๊ฐ„ ํ†ต์‹  ๋ถˆํ•„์š”
    • ๊ฒฝ์Ÿ ์ƒํƒœ ๊ณ ๋ ค์‚ฌํ•ญ: ํ˜„์žฌ ๊ตฌํ˜„์—์„œ๋Š” ๋ณ‘๋ ฌ ๋ฆฌ๋•์…˜ ์ค‘์— ์ฝ๊ธฐ-์“ฐ๊ธฐ ์ถฉ๋Œ์ด ๋ฐœ์ƒํ•  ์ˆ˜ ์žˆ์œผ๋ฉฐ, ๋ช…์‹œ์ ์ธ ์ฝ๊ธฐ-์“ฐ๊ธฐ ๋‹จ๊ณ„ ๋ถ„๋ฆฌ๋กœ ํ•ด๊ฒฐํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค

๋ณต์žก๋„ ๋ถ„์„

  • ์‹œ๊ฐ„: ํ–‰๋‹น \(O(\log n)\), n์€ ํ–‰์˜ ๊ธธ์ด
  • ๊ณต๊ฐ„: ๋ธ”๋ก๋‹น \(O(TPB)\) ๊ณต์œ  ๋ฉ”๋ชจ๋ฆฌ
  • ์ „์ฒด ๋ณ‘๋ ฌ ์‹œ๊ฐ„: ์Šค๋ ˆ๋“œ๊ฐ€ ์ถฉ๋ถ„ํ•  ๋•Œ \(O(\log n)\)
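The logarithmic bound follows directly from the stride-halving loop: it runs exactly \(\log_2 \text{TPB}\) iterations no matter what the row contains. A quick Python check (illustrative):

```python
import math

def reduction_steps(tpb):
    """Count iterations of the stride-halving loop for a block of tpb threads."""
    steps = 0
    stride = tpb // 2
    while stride > 0:
        steps += 1
        stride //= 2
    return steps

# The loop count matches log2 of the (power-of-two) block size.
for tpb in [8, 32, 256, 1024]:
    assert reduction_steps(tpb) == int(math.log2(tpb))
print(reduction_steps(8))  # 3 steps for TPB = 8
```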