๊ธฐ๋ณธ ๋ฒ์
1D LayoutTensor a์ ๋ํด ๋์ ํฉ์ ๊ณ์ฐํ๊ณ ๊ฒฐ๊ณผ๋ฅผ 1D LayoutTensor output์ ์ ์ฅํ๋ ์ปค๋์ ๊ตฌํํ์ธ์.
์ฐธ๊ณ : a์ ํฌ๊ธฐ๊ฐ ๋ธ๋ก ํฌ๊ธฐ๋ณด๋ค ํฐ ๊ฒฝ์ฐ, ๊ฐ ๋ธ๋ก์ ํฉ๊ณ๋ง ์ ์ฅํฉ๋๋ค.
๊ตฌ์ฑ
- ๋ฐฐ์ด ํฌ๊ธฐ:
SIZE = 8 - ๋ธ๋ก๋น ์ค๋ ๋ ์:
TPB = 8 - ๋ธ๋ก ์: 1
- ๊ณต์ ๋ฉ๋ชจ๋ฆฌ:
TPB๊ฐ ์์
์ฐธ๊ณ :
- ๋ฐ์ดํฐ ๋ก๋ฉ: ๊ฐ ์ค๋ ๋๊ฐ LayoutTensor ์ ๊ทผ์ ํตํด ์์ ํ๋๋ฅผ ๋ก๋
- ๋ฉ๋ชจ๋ฆฌ ํจํด: address_space๋ฅผ ์ง์ ํ LayoutTensor๋ก ์ค๊ฐ ๊ฒฐ๊ณผ๋ฅผ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ์ ์ ์ฅ
- ์ค๋ ๋ ๋๊ธฐํ: ์ฐ์ฐ ๋จ๊ณ ๊ฐ ์กฐ์จ
- ์ ๊ทผ ํจํด: ์คํธ๋ผ์ด๋ ๊ธฐ๋ฐ ๋ณ๋ ฌ ์ฐ์ฐ
- ํ์ ์์ ์ฑ: LayoutTensor์ ํ์ ์์คํ ํ์ฉ
์์ฑํ ์ฝ๋
comptime TPB = 8
comptime SIZE = 8
comptime BLOCKS_PER_GRID = (1, 1)
comptime THREADS_PER_BLOCK = (TPB, 1)
comptime dtype = DType.float32
comptime layout = Layout.row_major(SIZE)
fn prefix_sum_simple[
layout: Layout
](
output: LayoutTensor[dtype, layout, MutAnyOrigin],
a: LayoutTensor[dtype, layout, ImmutAnyOrigin],
size: UInt,
):
global_i = block_dim.x * block_idx.x + thread_idx.x
local_i = thread_idx.x
# FILL ME IN (roughly 18 lines)
์ ์ฒด ํ์ผ ๋ณด๊ธฐ: problems/p14/p14.mojo
ํ
- ๋ฐ์ดํฐ๋ฅผ
shared[local_i]์ ๋ก๋ offset = 1์์ ์์ํด ๋งค ๋จ๊ณ๋ง๋ค 2๋ฐฐ๋ก ์ฆ๊ฐlocal_i >= offset์ธ ์์์ ๋ํด ๋ง์ ์ํ- ๊ฐ ๋จ๊ณ ์ฌ์ด์
barrier()ํธ์ถ
์ฝ๋ ์คํ
์๋ฃจ์ ์ ํ ์คํธํ๋ ค๋ฉด ํฐ๋ฏธ๋์์ ๋ค์ ๋ช ๋ น์ด๋ฅผ ์คํํ์ธ์:
pixi run p14 --simple
pixi run -e amd p14 --simple
pixi run -e apple p14 --simple
uv run poe p14 --simple
ํผ์ฆ์ ์์ง ํ์ง ์์๋ค๋ฉด ์ถ๋ ฅ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
out: DeviceBuffer([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
expected: HostBuffer([0.0, 1.0, 3.0, 6.0, 10.0, 15.0, 21.0, 28.0])
์๋ฃจ์
fn prefix_sum_simple[
layout: Layout
](
output: LayoutTensor[dtype, layout, MutAnyOrigin],
a: LayoutTensor[dtype, layout, ImmutAnyOrigin],
size: UInt,
):
global_i = block_dim.x * block_idx.x + thread_idx.x
local_i = thread_idx.x
shared = LayoutTensor[
dtype,
Layout.row_major(TPB),
MutAnyOrigin,
address_space = AddressSpace.SHARED,
].stack_allocation()
if global_i < size:
shared[local_i] = a[global_i]
barrier()
offset = UInt(1)
for i in range(Int(log2(Scalar[dtype](TPB)))):
var current_val: output.element_type = 0
if local_i >= offset and local_i < size:
current_val = shared[local_i - offset] # read
barrier()
if local_i >= offset and local_i < size:
shared[local_i] += current_val
barrier()
offset *= 2
if global_i < size:
output[global_i] = shared[local_i]
๋ณ๋ ฌ (ํฌํจ) ๋์ ํฉ ์๊ณ ๋ฆฌ์ฆ์ ๋ค์๊ณผ ๊ฐ์ด ๋์ํฉ๋๋ค:
์ค์ ๋ฐ ๊ตฌ์ฑ
TPB(๋ธ๋ก๋น ์ค๋ ๋ ์) = 8SIZE(๋ฐฐ์ด ํฌ๊ธฐ) = 8
๊ฒฝ์ ์ํ ๋ฐฉ์ง
์ด ์๊ณ ๋ฆฌ์ฆ์ ๋ช ์์ ๋๊ธฐํ๋ฅผ ํตํด ์ฝ๊ธฐ-์ฐ๊ธฐ ์ถฉ๋์ ๋ฐฉ์งํฉ๋๋ค:
- ์ฝ๊ธฐ ๋จ๊ณ: ๋ชจ๋ ์ค๋ ๋๊ฐ ๋จผ์ ํ์ํ ๊ฐ์ ๋ก์ปฌ ๋ณ์
current_val์ ์ฝ์ด๋ - ๋๊ธฐํ:
barrier()๋ก ๋ชจ๋ ์ฝ๊ธฐ๊ฐ ์๋ฃ๋ ํ์์ผ ์ฐ๊ธฐ๊ฐ ์์๋๋๋ก ๋ณด์ฅ - ์ฐ๊ธฐ ๋จ๊ณ: ๋ชจ๋ ์ค๋ ๋๊ฐ ๊ณ์ฐ๋ ๊ฐ์ ์์ ํ๊ฒ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ์ ๊ธฐ๋ก
์ด๋ ๊ฒ ํ๋ฉด ์ฌ๋ฌ ์ค๋ ๋๊ฐ ๋์์ ๊ฐ์ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ์์น๋ฅผ ์ฝ๊ณ ์ธ ๋ ๋ฐ์ํ๋ ๊ฒฝ์ ์ํ๋ฅผ ๋ฐฉ์งํ ์ ์์ต๋๋ค.ในใ
๋์์ ์ ๊ทผ: ๊ฒฝ์ ์ํ๋ฅผ ๋ฐฉ์งํ๋ ๋ ๋ค๋ฅธ ๋ฐฉ๋ฒ์ ๋๋ธ ๋ฒํผ๋ง ์ ๋๋ค. ๊ณต์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ 2๋ฐฐ๋ก ํ ๋นํ ๋ค, ํ ๋ฒํผ์์ ์ฝ๊ณ ๋ค๋ฅธ ๋ฒํผ์ ์ฐ๋ ๊ฒ์ ๋ฒ๊ฐ์ ์ํํ๋ ๋ฐฉ์์ ๋๋ค. ์ด ๋ฐฉ๋ฒ์ ๊ฒฝ์ ์ํ๋ฅผ ์์ ํ ์ ๊ฑฐํ์ง๋ง, ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ด ๋์ด๋๊ณ ๋ณต์ก๋๊ฐ ์ฌ๋ผ๊ฐ๋๋ค. ํ์ต ๋ชฉ์ ์ผ๋ก๋ ์ดํดํ๊ธฐ ๋ ์ฌ์ด ๋ช ์์ ๋๊ธฐํ ๋ฐฉ์์ ์ฌ์ฉํฉ๋๋ค.
์ค๋ ๋ ๋งคํ
thread_idx.x: \([0, 1, 2, 3, 4, 5, 6, 7]\) (local_i)block_idx.x: \([0, 0, 0, 0, 0, 0, 0, 0]\)global_i: \([0, 1, 2, 3, 4, 5, 6, 7]\) (block_idx.x * TPB + thread_idx.x)
๊ณต์ ๋ฉ๋ชจ๋ฆฌ์ ์ด๊ธฐ ๋ก๋
Threads: Tโ Tโ Tโ Tโ Tโ Tโ
Tโ Tโ
Input array: [0 1 2 3 4 5 6 7]
shared: [0 1 2 3 4 5 6 7]
โ โ โ โ โ โ โ โ
Tโ Tโ Tโ Tโ Tโ Tโ
Tโ Tโ
Offset = 1: ์ฒซ ๋ฒ์งธ ๋ณ๋ ฌ ๋จ๊ณ
ํ์ฑ ์ค๋ ๋: \(T_1 \ldots T_7\) (local_i โฅ 1์ธ ์ค๋ ๋)
์ฝ๊ธฐ ๋จ๊ณ: ๊ฐ ์ค๋ ๋๊ฐ ํ์ํ ๊ฐ์ ์ฝ์:
Tโ reads shared[0] = 0 Tโ
reads shared[4] = 4
Tโ reads shared[1] = 1 Tโ reads shared[5] = 5
Tโ reads shared[2] = 2 Tโ reads shared[6] = 6
Tโ reads shared[3] = 3
๋๊ธฐํ: barrier()๋ก ๋ชจ๋ ์ฝ๊ธฐ ์๋ฃ๋ฅผ ๋ณด์ฅ
์ฐ๊ธฐ ๋จ๊ณ: ๊ฐ ์ค๋ ๋๊ฐ ์ฝ์ ๊ฐ์ ํ์ฌ ์์น์ ๋ํจ:
Before: [0 1 2 3 4 5 6 7]
Add: +0 +1 +2 +3 +4 +5 +6
| | | | | | |
Result: [0 1 3 5 7 9 11 13]
โ โ โ โ โ โ โ
Tโ Tโ Tโ Tโ Tโ
Tโ Tโ
Offset = 2: ๋ ๋ฒ์งธ ๋ณ๋ ฌ ๋จ๊ณ
ํ์ฑ ์ค๋ ๋: \(T_2 \ldots T_7\) (local_i โฅ 2์ธ ์ค๋ ๋)
์ฝ๊ธฐ ๋จ๊ณ: ๊ฐ ์ค๋ ๋๊ฐ ํ์ํ ๊ฐ์ ์ฝ์:
Tโ reads shared[0] = 0 Tโ
reads shared[3] = 5
Tโ reads shared[1] = 1 Tโ reads shared[4] = 7
Tโ reads shared[2] = 3 Tโ reads shared[5] = 9
๋๊ธฐํ: barrier()๋ก ๋ชจ๋ ์ฝ๊ธฐ ์๋ฃ๋ฅผ ๋ณด์ฅ
์ฐ๊ธฐ ๋จ๊ณ: ๊ฐ ์ค๋ ๋๊ฐ ์ฝ์ ๊ฐ์ ๋ํจ:
Before: [0 1 3 5 7 9 11 13]
Add: +0 +1 +3 +5 +7 +9
| | | | | |
Result: [0 1 3 6 10 14 18 22]
โ โ โ โ โ โ
Tโ Tโ Tโ Tโ
Tโ Tโ
Offset = 4: ์ธ ๋ฒ์งธ ๋ณ๋ ฌ ๋จ๊ณ
ํ์ฑ ์ค๋ ๋: \(T_4 \ldots T_7\) (local_i โฅ 4์ธ ์ค๋ ๋)
์ฝ๊ธฐ ๋จ๊ณ: ๊ฐ ์ค๋ ๋๊ฐ ํ์ํ ๊ฐ์ ์ฝ์:
Tโ reads shared[0] = 0 Tโ reads shared[2] = 3
Tโ
reads shared[1] = 1 Tโ reads shared[3] = 6
๋๊ธฐํ: barrier()๋ก ๋ชจ๋ ์ฝ๊ธฐ ์๋ฃ๋ฅผ ๋ณด์ฅ
์ฐ๊ธฐ ๋จ๊ณ: ๊ฐ ์ค๋ ๋๊ฐ ์ฝ์ ๊ฐ์ ๋ํจ:
Before: [0 1 3 6 10 14 18 22]
Add: +0 +1 +3 +6
| | | |
Result: [0 1 3 6 10 15 21 28]
โ โ โ โ
Tโ Tโ
Tโ Tโ
์ต์ข ๊ฒฐ๊ณผ๋ฅผ output์ ๊ธฐ๋ก
Threads: Tโ Tโ Tโ Tโ Tโ Tโ
Tโ Tโ
global_i: 0 1 2 3 4 5 6 7
output: [0 1 3 6 10 15 21 28]
โ โ โ โ โ โ โ โ
Tโ Tโ Tโ Tโ Tโ Tโ
Tโ Tโ
์ฃผ์ ๊ตฌํ ์์ธ
๋๊ธฐํ ํจํด: ๊ฐ ๋ฐ๋ณต์ ์๊ฒฉํ ์ฝ๊ธฐ โ ๋๊ธฐํ โ ์ฐ๊ธฐ ํจํด์ ๋ฐ๋ฆ ๋๋ค:
var current_val: out.element_type = 0- ๋ก์ปฌ ๋ณ์ ์ด๊ธฐํcurrent_val = shared[local_i - offset]- ์ฝ๊ธฐ ๋จ๊ณ (์กฐ๊ฑด ์ถฉ์กฑ ์)barrier()- ๊ฒฝ์ ์ํ ๋ฐฉ์ง๋ฅผ ์ํ ๋ช ์์ ๋๊ธฐํshared[local_i] += current_val- ์ฐ๊ธฐ ๋จ๊ณ (์กฐ๊ฑด ์ถฉ์กฑ ์)barrier()- ๋ค์ ๋ฐ๋ณต ์ ๋๊ธฐํ
๊ฒฝ์ ์ํ ๋ฐฉ์ง: ์ฝ๊ธฐ์ ์ฐ๊ธฐ๋ฅผ ๋ช ์์ ์ผ๋ก ๋ถ๋ฆฌํ์ง ์์ผ๋ฉด ์ฌ๋ฌ ์ค๋ ๋๊ฐ ๋์์ ๊ฐ์ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ์์น์ ์ ๊ทผํ์ฌ ๋ฏธ์ ์ ๋์์ด ๋ฐ์ํ ์ ์์ต๋๋ค. ๋ช ์์ ๋๊ธฐํ๋ฅผ ์ฌ์ฉํ 2๋จ๊ณ ์ ๊ทผ ๋ฐฉ์์ด ์ ํ์ฑ์ ๋ณด์ฅํฉ๋๋ค.
๋ฉ๋ชจ๋ฆฌ ์์ ์ฑ: ์๊ณ ๋ฆฌ์ฆ์ ๋ค์์ ํตํด ๋ฉ๋ชจ๋ฆฌ ์์ ์ฑ์ ์ ์งํฉ๋๋ค:
if local_i >= offset and local_i < size๋ก ๊ฒฝ๊ณ ๊ฒ์ฌ- ์์ ๋ณ์์ ์ ์ ํ ์ด๊ธฐํ
- ๊ฒฝ์ ์ํ๋ฅผ ๋ฐฉ์งํ๋ ์กฐ์จ๋ ์ ๊ทผ ํจํด
์ด ์๋ฃจ์
์ barrier()๋ฅผ ์ฌ์ฉํด ๋จ๊ณ ๊ฐ ์ฌ๋ฐ๋ฅธ ๋๊ธฐํ๋ฅผ ๋ณด์ฅํ๊ณ , if global_i < size๋ก ๋ฐฐ์ด ๊ฒฝ๊ณ ๊ฒ์ฌ๋ฅผ ์ฒ๋ฆฌํฉ๋๋ค. ์ต์ข
๊ฒฐ๊ณผ๋ ๊ฐ ์์ \(i\)๊ฐ \(\sum_{j=0}^{i} a[j]\) ๋ฅผ ํฌํจํ๋ ํฌํจ ๋์ ํฉ์
๋๋ค.