๊ฐ์
1D LayoutTensor a์์ ๊ฐ ์์น์ ์ง์ 3๊ฐ ๊ฐ์ ํฉ์ ๊ณ์ฐํ์ฌ 1D LayoutTensor output์ ์ ์ฅํ๋ ์ปค๋์ ๊ตฌํํ์ธ์.
์ฐธ๊ณ : ๊ฐ ์์น๋ง๋ค ์ค๋ ๋ 1๊ฐ๊ฐ ์์ต๋๋ค. ์ค๋ ๋๋น ์ ์ญ ์ฝ๊ธฐ 1ํ, ์ ์ญ ์ฐ๊ธฐ 1ํ๋ง ํ์ํฉ๋๋ค.
ํต์ฌ ๊ฐ๋
์ด ํผ์ฆ์์ ๋ฐฐ์ธ ๋ด์ฉ:
- LayoutTensor๋ก ์ฌ๋ผ์ด๋ฉ ์๋์ฐ ์ฐ์ฐ ๊ตฌํํ๊ธฐ
- Puzzle 8์์ ๋ค๋ฃฌ LayoutTensor ์ฃผ์ ๊ณต๊ฐ(address_space)์ผ๋ก ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ๊ด๋ฆฌํ๊ธฐ
- ํจ์จ์ ์ธ ์ด์ ์ ๊ทผ ํจํด
- ๊ฒฝ๊ณ ์กฐ๊ฑด ์ฒ๋ฆฌ
ํต์ฌ์ LayoutTensor๊ฐ ํจ์จ์ ์ธ ์๋์ฐ ๊ธฐ๋ฐ ์ฐ์ฐ์ ์ ์งํ๋ฉด์๋ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ๊ด๋ฆฌ๋ฅผ ๊ฐ์ํํ๋ ๋ฐฉ๋ฒ์ ๋๋ค.
๊ตฌ์ฑ
- ๋ฐฐ์ด ํฌ๊ธฐ:
SIZE = 8 - ๋ธ๋ก๋น ์ค๋ ๋ ์:
TPB = 8 - ์๋์ฐ ํฌ๊ธฐ: 3
- ๊ณต์ ๋ฉ๋ชจ๋ฆฌ:
TPB๊ฐ
์ฐธ๊ณ :
- LayoutTensor ํ ๋น:
LayoutTensor[dtype, Layout.row_major(TPB), MutAnyOrigin, address_space = AddressSpace.SHARED].stack_allocation()์ฌ์ฉ - ์๋์ฐ ์ ๊ทผ: 3๊ฐ์ง๋ฆฌ ์๋์ฐ์ ์์ฐ์ค๋ฌ์ด ์ธ๋ฑ์ฑ
- ๊ฒฝ๊ณ ์ฒ๋ฆฌ: ์ฒ์ ๋ ์์น๋ ํน์ ์ผ์ด์ค
- ๋ฉ๋ชจ๋ฆฌ ํจํด: ์ค๋ ๋๋น ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ๋ก๋ 1ํ
์์ฑํ ์ฝ๋
comptime TPB = 8
comptime SIZE = 8
comptime BLOCKS_PER_GRID = (1, 1)
comptime THREADS_PER_BLOCK = (TPB, 1)
comptime dtype = DType.float32
comptime layout = Layout.row_major(SIZE)
fn pooling[
layout: Layout
](
output: LayoutTensor[dtype, layout, MutAnyOrigin],
a: LayoutTensor[dtype, layout, ImmutAnyOrigin],
size: UInt,
):
# Allocate shared memory using tensor builder
shared = LayoutTensor[
dtype,
Layout.row_major(TPB),
MutAnyOrigin,
address_space = AddressSpace.SHARED,
].stack_allocation()
global_i = block_dim.x * block_idx.x + thread_idx.x
local_i = thread_idx.x
# FIX ME IN (roughly 10 lines)
์ ์ฒด ํ์ผ ๋ณด๊ธฐ: problems/p11/p11_layout_tensor.mojo
ํ
- LayoutTensor์ ์ฃผ์ ๊ณต๊ฐ(address_space)์ผ๋ก ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ์์ฑ
- ์์ฐ์ค๋ฌ์ด ์ธ๋ฑ์ฑ์ผ๋ก ๋ฐ์ดํฐ ๋ก๋:
shared[local_i] = a[global_i] - ์ฒ์ ๋ ์์น๋ฅผ ํน์ ์ผ์ด์ค๋ก ์ฒ๋ฆฌ
- ์๋์ฐ ์ฐ์ฐ์ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ํ์ฉ
- ๊ฒฝ๊ณ ์ด๊ณผ ์ ๊ทผ์ ๊ฐ๋ ์ถ๊ฐ
์ฝ๋ ์คํ
์๋ฃจ์ ์ ํ ์คํธํ๋ ค๋ฉด ํฐ๋ฏธ๋์์ ๋ค์ ๋ช ๋ น์ด๋ฅผ ์คํํ์ธ์:
pixi run p11_layout_tensor
pixi run -e amd p11_layout_tensor
pixi run -e apple p11_layout_tensor
uv run poe p11_layout_tensor
ํผ์ฆ์ ์์ง ํ์ง ์์๋ค๋ฉด ์ถ๋ ฅ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
out: HostBuffer([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
expected: HostBuffer([0.0, 1.0, 3.0, 6.0, 9.0, 12.0, 15.0, 18.0])
์๋ฃจ์
fn pooling[
layout: Layout
](
output: LayoutTensor[dtype, layout, MutAnyOrigin],
a: LayoutTensor[dtype, layout, ImmutAnyOrigin],
size: UInt,
):
# Allocate shared memory using tensor builder
shared = LayoutTensor[
dtype,
Layout.row_major(TPB),
MutAnyOrigin,
address_space = AddressSpace.SHARED,
].stack_allocation()
global_i = block_dim.x * block_idx.x + thread_idx.x
local_i = thread_idx.x
# Load data into shared memory
if global_i < size:
shared[local_i] = a[global_i]
# Synchronize threads within block
barrier()
# Handle first two special cases
if global_i == 0:
output[0] = shared[0]
elif global_i == 1:
output[1] = shared[0] + shared[1]
# Handle general case
elif UInt(1) < global_i < size:
output[global_i] = (
shared[local_i - 2] + shared[local_i - 1] + shared[local_i]
)
LayoutTensor๋ฅผ ํ์ฉํ ์ฌ๋ผ์ด๋ฉ ์๋์ฐ ํฉ๊ณ ๊ตฌํ์ ๋๋ค. ์ฃผ์ ๋จ๊ณ๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
-
๊ณต์ ๋ฉ๋ชจ๋ฆฌ ์ค์
-
LayoutTensor๊ฐ ์ฃผ์ ๊ณต๊ฐ(address_space)์ผ๋ก ๋ธ๋ก ๋ก์ปฌ ์ ์ฅ์๋ฅผ ์์ฑ:
shared = LayoutTensor[dtype, Layout.row_major(TPB), MutAnyOrigin, address_space = AddressSpace.SHARED].stack_allocation() -
๊ฐ ์ค๋ ๋๊ฐ ํ๋์ฉ ๋ก๋:
Input array: [0.0 1.0 2.0 3.0 4.0 5.0 6.0 7.0] Block shared: [0.0 1.0 2.0 3.0 4.0 5.0 6.0 7.0] -
barrier()๋ก ๋ชจ๋ ๋ฐ์ดํฐ ๋ก๋ ์๋ฃ๋ฅผ ๋ณด์ฅ
-
-
๊ฒฝ๊ณ ์ผ์ด์ค
-
์์น 0: ํ๋๋ง
output[0] = shared[0] = 0.0 -
์์น 1: ์ฒ์ ๋ ๊ฐ์ ํฉ
output[1] = shared[0] + shared[1] = 0.0 + 1.0 = 1.0
-
-
๋ฉ์ธ ์๋์ฐ ์ฐ์ฐ
-
์์น 2 ์ดํ:
Position 2: shared[0] + shared[1] + shared[2] = 0.0 + 1.0 + 2.0 = 3.0 Position 3: shared[1] + shared[2] + shared[3] = 1.0 + 2.0 + 3.0 = 6.0 Position 4: shared[2] + shared[3] + shared[4] = 2.0 + 3.0 + 4.0 = 9.0 ... -
LayoutTensor์ ์์ฐ์ค๋ฌ์ด ์ธ๋ฑ์ฑ:
# 3๊ฐ์ง๋ฆฌ ์ฌ๋ผ์ด๋ฉ ์๋์ฐ window_sum = shared[i-2] + shared[i-1] + shared[i]
-
-
๋ฉ๋ชจ๋ฆฌ ์ ๊ทผ ํจํด
- ์ค๋ ๋๋ง๋ค ๊ณต์ ํ ์๋ก ์ ์ญ ์ฝ๊ธฐ 1ํ
- ๊ณต์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ํตํ ํจ์จ์ ์ธ ์ด์ ์ ๊ทผ
- LayoutTensor์ ์ฅ์ :
- ์๋ ๊ฒฝ๊ณ ๊ฒ์ฌ
- ์์ฐ์ค๋ฌ์ด ์๋์ฐ ์ธ๋ฑ์ฑ
- ๋ ์ด์์์ ์ธ์ํ๋ ๋ฉ๋ชจ๋ฆฌ ์ ๊ทผ
- ์ ๊ณผ์ ์ ๊ฑธ์น ํ์ ์์ ์ฑ
๊ณต์ ๋ฉ๋ชจ๋ฆฌ์ ์ฑ๋ฅ๊ณผ LayoutTensor์ ์์ ์ฑ ๋ฐ ํธ์์ฑ์ ๊ฒฐํฉํ ๋ฐฉ์์ ๋๋ค:
- ์ ์ญ ๋ฉ๋ชจ๋ฆฌ ์ ๊ทผ ์ต์ํ
- ์๋์ฐ ์ฐ์ฐ ๊ฐ์ํ
- ๊น๋ํ ๊ฒฝ๊ณ ์ฒ๋ฆฌ
- ๋ณํฉ ์ ๊ทผ ํจํด ์ ์ง
์ต์ข ์ถ๋ ฅ์ ๋์ ์๋์ฐ ํฉ๊ณ์ ๋๋ค:
[0.0, 1.0, 3.0, 6.0, 9.0, 12.0, 15.0, 18.0]