Simple Version

Implement a kernel that computes a prefix sum over the 1D LayoutTensor a and stores the result in the 1D LayoutTensor output.

Note: If the size of a is greater than the block size, only the sum of each block is stored.

Configuration

  • Array size: SIZE = 8
  • Threads per block: TPB = 8
  • Number of blocks: 1
  • Shared memory: TPB elements

Notes:

  • Data loading: each thread loads one element through LayoutTensor access
  • Memory pattern: intermediate results are kept in shared memory, using a LayoutTensor with address_space set
  • Thread synchronization: coordination between computation phases
  • Access pattern: stride-based parallel computation
  • Type safety: relies on LayoutTensor's type system

Code to complete

comptime TPB = 8
comptime SIZE = 8
comptime BLOCKS_PER_GRID = (1, 1)
comptime THREADS_PER_BLOCK = (TPB, 1)
comptime dtype = DType.float32
comptime layout = Layout.row_major(SIZE)


fn prefix_sum_simple[
    layout: Layout
](
    output: LayoutTensor[dtype, layout, MutAnyOrigin],
    a: LayoutTensor[dtype, layout, ImmutAnyOrigin],
    size: UInt,
):
    global_i = block_dim.x * block_idx.x + thread_idx.x
    local_i = thread_idx.x
    # FILL ME IN (roughly 18 lines)


View full file: problems/p14/p14.mojo

Tips
  1. Load data into shared[local_i]
  2. Start with offset = 1 and double it at each step
  3. Perform the addition for elements where local_i >= offset
  4. Call barrier() between steps

Running the code

To test your solution, run one of the following commands in your terminal:

pixi run p14 --simple
pixi run -e amd p14 --simple
pixi run -e apple p14 --simple
uv run poe p14 --simple

If you have not solved the puzzle yet, the output looks like this:

out: DeviceBuffer([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
expected: HostBuffer([0.0, 1.0, 3.0, 6.0, 10.0, 15.0, 21.0, 28.0])
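
The expected buffer is the inclusive prefix sum of the inputs 0 through 7. As a quick sanity check, a sequential CPU reference can be sketched in Python rather than Mojo. The helper name inclusive_prefix_sum is made up for this illustration, and the input values 0.0 to 7.0 are an assumption inferred from the expected output.

```python
from itertools import accumulate


def inclusive_prefix_sum(values):
    """Sequential reference: element i is the sum of values[0..i]."""
    return list(accumulate(values))


a = [float(i) for i in range(8)]  # assumed input: 0.0 .. 7.0
print(inclusive_prefix_sum(a))
# prints [0.0, 1.0, 3.0, 6.0, 10.0, 15.0, 21.0, 28.0]
```

A correct kernel should reproduce this list in out.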

Solution

fn prefix_sum_simple[
    layout: Layout
](
    output: LayoutTensor[dtype, layout, MutAnyOrigin],
    a: LayoutTensor[dtype, layout, ImmutAnyOrigin],
    size: UInt,
):
    global_i = block_dim.x * block_idx.x + thread_idx.x
    local_i = thread_idx.x
    # Shared-memory scratch buffer, one element per thread in the block
    shared = LayoutTensor[
        dtype,
        Layout.row_major(TPB),
        MutAnyOrigin,
        address_space = AddressSpace.SHARED,
    ].stack_allocation()
    if global_i < size:
        shared[local_i] = a[global_i]

    barrier()

    # log2(TPB) doubling steps: offset = 1, 2, 4
    offset = UInt(1)
    for i in range(Int(log2(Scalar[dtype](TPB)))):
        var current_val: output.element_type = 0
        if local_i >= offset and local_i < size:
            current_val = shared[local_i - offset]  # read

        barrier()  # all reads finish before any write
        if local_i >= offset and local_i < size:
            shared[local_i] += current_val  # write

        barrier()  # all writes finish before the next step
        offset *= 2

    if global_i < size:
        output[global_i] = shared[local_i]


The parallel (inclusive) prefix-sum algorithm works as follows:

Setup and configuration

  • TPB (threads per block) = 8
  • SIZE (array size) = 8

Race condition prevention

The algorithm uses explicit synchronization to prevent read-write conflicts:

  • Read phase: every thread first reads the value it needs into the local variable current_val
  • Synchronization: barrier() guarantees that all reads finish before any write begins
  • Write phase: every thread then safely writes its computed value back to shared memory

This prevents the race condition that arises when several threads read and write the same shared memory location at the same time.
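
To make the hazard concrete, here is a minimal CPU simulation in Python (not Mojo). It compares the read, then write pattern with one possible bad schedule, in which the threads run one after another in index order and update shared in place. Both function names are hypothetical, and a real GPU may interleave threads in many other ways, so this shows only that some schedule breaks the result.

```python
def scan_two_phase(values):
    """Read phase, then write phase, for each offset (mirrors the barrier pattern)."""
    shared = list(values)
    n = len(shared)
    offset = 1
    while offset < n:
        # Read phase: every active "thread" snapshots the value it needs.
        reads = [shared[i - offset] if i >= offset else 0 for i in range(n)]
        # A barrier would sit here. Write phase: add the snapshots.
        shared = [s + r for s, r in zip(shared, reads)]
        offset *= 2
    return shared


def scan_no_separation(values):
    """One bad schedule: in-place updates in index order, no read/write split."""
    shared = list(values)
    n = len(shared)
    offset = 1
    while offset < n:
        for i in range(offset, n):
            # shared[i - offset] may already hold this step's updated value.
            shared[i] += shared[i - offset]
        offset *= 2
    return shared


data = list(range(8))
print(scan_two_phase(data))      # prints [0, 1, 3, 6, 10, 15, 21, 28]
print(scan_two_phase(data) == scan_no_separation(data))  # prints False
```

In the unseparated version, a thread can pick up a neighbor's value that was already updated in the same step, so later steps add some inputs more than once.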

Alternative approach: another way to prevent the race condition is double buffering. Allocate twice the shared memory, then alternate between reading from one buffer and writing to the other. This removes the race condition entirely, but it uses more shared memory and adds complexity. For learning purposes, this puzzle uses the explicit synchronization approach, which is easier to follow.
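
The double-buffering idea can be sketched on the CPU in Python (not Mojo). This is an illustration under assumptions, not the puzzle's solution: the name scan_double_buffered and the two-list layout are hypothetical, and a real kernel would still need a barrier between steps so that every write to the destination buffer is visible before it becomes the source.

```python
def scan_double_buffered(values):
    """Inclusive scan that reads from one buffer and writes to the other."""
    n = len(values)
    buffers = [list(values), [0] * n]  # twice the storage of the in-place version
    src = 0
    offset = 1
    while offset < n:
        dst = 1 - src
        for i in range(n):  # each i plays the role of one thread
            if i >= offset:
                buffers[dst][i] = buffers[src][i] + buffers[src][i - offset]
            else:
                buffers[dst][i] = buffers[src][i]  # carry the value over unchanged
        # On a GPU, one barrier per step would go here.
        src = dst
        offset *= 2
    return buffers[src]


print(scan_double_buffered(list(range(8))))
# prints [0, 1, 3, 6, 10, 15, 21, 28]
```

Because reads and writes never touch the same buffer within a step, no temporary like current_val is needed. The cost is the second buffer and the bookkeeping of which buffer is current.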

Thread mapping

  • thread_idx.x: \([0, 1, 2, 3, 4, 5, 6, 7]\) (local_i)
  • block_idx.x: \([0, 0, 0, 0, 0, 0, 0, 0]\)
  • global_i: \([0, 1, 2, 3, 4, 5, 6, 7]\) (block_idx.x * TPB + thread_idx.x)

Initial load into shared memory

Threads:      T₀   T₁   T₂   T₃   T₄   T₅   T₆   T₇
Input array:  [0    1    2    3    4    5    6    7]
shared:       [0    1    2    3    4    5    6    7]
               ↑    ↑    ↑    ↑    ↑    ↑    ↑    ↑
              T₀   T₁   T₂   T₃   T₄   T₅   T₆   T₇

Offset = 1: first parallel step

Active threads: \(T_1 \ldots T_7\) (threads with local_i ≥ 1)

Read phase: each thread reads the value it needs:

T₁ reads shared[0] = 0    T₅ reads shared[4] = 4
T₂ reads shared[1] = 1    T₆ reads shared[5] = 5
T₃ reads shared[2] = 2    T₇ reads shared[6] = 6
T₄ reads shared[3] = 3

Synchronization: barrier() ensures that all reads have completed

Write phase: each thread adds the value it read to its current position:

Before:      [0    1    2    3    4    5    6    7]
Add:              +0   +1   +2   +3   +4   +5   +6
                   |    |    |    |    |    |    |
Result:      [0    1    3    5    7    9    11   13]
                   ↑    ↑    ↑    ↑    ↑    ↑    ↑
                  T₁   T₂   T₃   T₄   T₅   T₆   T₇

Offset = 2: second parallel step

Active threads: \(T_2 \ldots T_7\) (threads with local_i ≥ 2)

Read phase: each thread reads the value it needs:

T₂ reads shared[0] = 0    T₅ reads shared[3] = 5
T₃ reads shared[1] = 1    T₆ reads shared[4] = 7
T₄ reads shared[2] = 3    T₇ reads shared[5] = 9

Synchronization: barrier() ensures that all reads have completed

Write phase: each thread adds the value it read:

Before:      [0    1    3    5    7    9    11   13]
Add:                   +0   +1   +3   +5   +7   +9
                        |    |    |    |    |    |
Result:      [0    1    3    6    10   14   18   22]
                        ↑    ↑    ↑    ↑    ↑    ↑
                       T₂   T₃   T₄   T₅   T₆   T₇

Offset = 4: third parallel step

Active threads: \(T_4 \ldots T_7\) (threads with local_i ≥ 4)

Read phase: each thread reads the value it needs:

T₄ reads shared[0] = 0    T₆ reads shared[2] = 3
T₅ reads shared[1] = 1    T₇ reads shared[3] = 6

Synchronization: barrier() ensures that all reads have completed

Write phase: each thread adds the value it read:

Before:      [0    1    3    6    10   14   18   22]
Add:                              +0   +1   +3   +6
                                  |    |    |    |
Result:      [0    1    3    6    10   15   21   28]
                                  ↑    ↑    ↑    ↑
                                  T₄   T₅   T₆   T₇

Final write to output

Threads:      T₀   T₁   T₂   T₃   T₄   T₅   T₆   T₇
global_i:     0    1    2    3    4    5    6    7
output:       [0    1    3    6    10   15   21   28]
              ↑    ↑    ↑    ↑    ↑    ↑    ↑    ↑
              T₀   T₁   T₂   T₃   T₄   T₅   T₆   T₇
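
The three steps traced above can be replayed with a short Python simulation (not Mojo). The function name trace_scan is hypothetical, and the list comprehension stands in for the read phase that the kernel separates from the write phase with barrier().

```python
def trace_scan(values):
    """Return (offset, shared state) after each doubling step of the scan."""
    shared = list(values)
    n = len(shared)
    steps = []
    offset = 1
    while offset < n:
        # Read phase: snapshot shared[i - offset] for every active index.
        reads = [shared[i - offset] if i >= offset else 0 for i in range(n)]
        # Write phase: add the snapshots to the current positions.
        shared = [s + r for s, r in zip(shared, reads)]
        steps.append((offset, list(shared)))
        offset *= 2
    return steps


for offset, state in trace_scan(range(8)):
    print(f"offset={offset}: {state}")
# prints:
# offset=1: [0, 1, 3, 5, 7, 9, 11, 13]
# offset=2: [0, 1, 3, 6, 10, 14, 18, 22]
# offset=4: [0, 1, 3, 6, 10, 15, 21, 28]
```

Each printed row matches the Result line of the corresponding step.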

Key implementation details

Synchronization pattern: each iteration follows a strict read → sync → write pattern:

  1. var current_val: output.element_type = 0 - initialize the local variable
  2. current_val = shared[local_i - offset] - read phase (when the condition holds)
  3. barrier() - explicit synchronization to prevent the race condition
  4. shared[local_i] += current_val - write phase (when the condition holds)
  5. barrier() - synchronize before the next iteration

Race condition prevention: without an explicit separation of reads and writes, several threads could access the same shared memory location at the same time, leading to undefined behavior. The two-phase approach with explicit synchronization ensures correctness.

Memory safety: the algorithm maintains memory safety through:

  • Bounds checking with if local_i >= offset and local_i < size
  • Proper initialization of the temporary variable
  • Coordinated access patterns that prevent race conditions

The solution uses barrier() to ensure correct synchronization between phases and handles array bounds checking with if global_i < size. The final result is the inclusive prefix sum, in which each element \(i\) holds \(\sum_{j=0}^{i} a[j]\).
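
One way to see why three doubling steps suffice for TPB = 8 is to track the window of inputs that each element covers. The invariant below is not stated in the original text. It is a sketch that can be checked against the step-by-step traces, with \(k\) counting completed steps (step \(k\) uses offset \(2^{k-1}\)):

\[
\text{shared}[i] = \sum_{j=\max(0,\; i-2^{k}+1)}^{i} a[j] \quad \text{after } k \text{ steps.}
\]

For \(k = 0\) the window is just \(a[i]\), the initial load. The next step adds the neighbor at distance \(2^{k}\), whose window ends exactly where the current one begins, so the window length doubles. After \(k = \log_2 \text{TPB} = 3\) steps, \(i - 2^{3} + 1 \le 0\) for every \(i \le 7\), and the sum reduces to \(\sum_{j=0}^{i} a[j]\).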