๊ณต์ ๋ฉ๋ชจ๋ฆฌ ๋ฒ์
๊ฐ์
์ ๋ฐฉ ํ๋ ฌ \(A\) ์ \(B\) ์ ํ๋ ฌ ๊ณฑ์ ์ ๊ตฌํํ๊ณ ๊ฒฐ๊ณผ๋ฅผ \(\text{output}\)์ ์ ์ฅํ๋ ํผ์ฆ์ ๋๋ค. ๊ณต์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ํ์ฉํ์ฌ ๋ฉ๋ชจ๋ฆฌ ์ ๊ทผ ํจํด์ ์ต์ ํํฉ๋๋ค. ์ฐ์ฐ ์ ์ ํ๋ ฌ ๋ธ๋ก์ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ์ ๋ฏธ๋ฆฌ ๋ก๋ํ๋ ๋ฐฉ์์ ๋๋ค.
ํต์ฌ ๊ฐ๋
์ด ํผ์ฆ์์ ๋ค๋ฃจ๋ ๋ด์ฉ:
- LayoutTensor๋ฅผ ์ฌ์ฉํ ๋ธ๋ก ๋ก์ปฌ ๋ฉ๋ชจ๋ฆฌ ๊ด๋ฆฌ
- ์ค๋ ๋ ๋๊ธฐํ ํจํด
- ๊ณต์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ํ์ฉํ ๋ฉ๋ชจ๋ฆฌ ์ ๊ทผ ์ต์ ํ
- 2D ์ธ๋ฑ์ฑ์ ์ฌ์ฉํ ํ๋ ฅ์ ๋ฐ์ดํฐ ๋ก๋ฉ
- ํ๋ ฌ ์ฐ์ฐ์ LayoutTensor๋ฅผ ํจ์จ์ ์ผ๋ก ํ์ฉํ๊ธฐ
ํต์ฌ์ LayoutTensor๋ฅผ ํตํด ๋น ๋ฅธ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ํ์ฉํ์ฌ ๋น์ฉ์ด ํฐ ์ ์ญ ๋ฉ๋ชจ๋ฆฌ ์ ๊ทผ์ ์ต์ํํ๋ ๊ฒ์ ๋๋ค.
๊ตฌ์ฑ
- ํ๋ ฌ ํฌ๊ธฐ: \(\text{SIZE} \times \text{SIZE} = 2 \times 2\)
- ๋ธ๋ก๋น ์ค๋ ๋ ์: \(\text{TPB} \times \text{TPB} = 3 \times 3\)
- ๊ทธ๋ฆฌ๋ ์ฐจ์: \(1 \times 1\)
๋ ์ด์์ ๊ตฌ์ฑ:
- ์
๋ ฅ A:
Layout.row_major(SIZE, SIZE) - ์
๋ ฅ B:
Layout.row_major(SIZE, SIZE) - ์ถ๋ ฅ:
Layout.row_major(SIZE, SIZE) - ๊ณต์ ๋ฉ๋ชจ๋ฆฌ:
TPB ร TPBํฌ๊ธฐ์ LayoutTensor 2๊ฐ
๋ฉ๋ชจ๋ฆฌ ๊ตฌ์ฑ:
Global Memory (LayoutTensor): Shared Memory (LayoutTensor):
A[i,j]: Direct access a_shared[local_row, local_col]
B[i,j]: Direct access b_shared[local_row, local_col]
์์ฑํ ์ฝ๋
fn single_block_matmul[
layout: Layout, size: UInt
](
output: LayoutTensor[dtype, layout, MutAnyOrigin],
a: LayoutTensor[dtype, layout, ImmutAnyOrigin],
b: LayoutTensor[dtype, layout, ImmutAnyOrigin],
):
row = block_dim.y * block_idx.y + thread_idx.y
col = block_dim.x * block_idx.x + thread_idx.x
local_row = thread_idx.y
local_col = thread_idx.x
# FILL ME IN (roughly 12 lines)
์ ์ฒด ํ์ผ ๋ณด๊ธฐ: problems/p16/p16.mojo
ํ
- ์ ์ญ ์ธ๋ฑ์ค์ ๋ก์ปฌ ์ธ๋ฑ์ค๋ฅผ ์ฌ์ฉํ์ฌ ํ๋ ฌ์ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ์ ๋ก๋
- ๋ก๋ ํ
barrier()ํธ์ถ - ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ์ธ๋ฑ์ค๋ฅผ ์ฌ์ฉํ์ฌ ๋ด์ ๊ณ์ฐ
- ๋ชจ๋ ์ฐ์ฐ์์ ๋ฐฐ์ด ๊ฒฝ๊ณ ๊ฒ์ฌ
์ฝ๋ ์คํ
์๋ฃจ์ ์ ํ ์คํธํ๋ ค๋ฉด ํฐ๋ฏธ๋์์ ๋ค์ ๋ช ๋ น์ด๋ฅผ ์คํํ์ธ์:
pixi run p16 --single-block
pixi run -e amd p16 --single-block
pixi run -e apple p16 --single-block
uv run poe p16 --single-block
ํผ์ฆ์ ์์ง ํ์ง ์์๋ค๋ฉด ์ถ๋ ฅ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
out: HostBuffer([0.0, 0.0, 0.0, 0.0])
expected: HostBuffer([4.0, 6.0, 12.0, 22.0])
์๋ฃจ์
fn single_block_matmul[
layout: Layout, size: UInt
](
output: LayoutTensor[dtype, layout, MutAnyOrigin],
a: LayoutTensor[dtype, layout, ImmutAnyOrigin],
b: LayoutTensor[dtype, layout, ImmutAnyOrigin],
):
row = block_dim.y * block_idx.y + thread_idx.y
col = block_dim.x * block_idx.x + thread_idx.x
local_row = thread_idx.y
local_col = thread_idx.x
a_shared = LayoutTensor[
dtype,
Layout.row_major(TPB, TPB),
MutAnyOrigin,
address_space = AddressSpace.SHARED,
].stack_allocation()
b_shared = LayoutTensor[
dtype,
Layout.row_major(TPB, TPB),
MutAnyOrigin,
address_space = AddressSpace.SHARED,
].stack_allocation()
if row < size and col < size:
a_shared[local_row, local_col] = a[row, col]
b_shared[local_row, local_col] = b[row, col]
barrier()
if row < size and col < size:
var acc: output.element_type = 0
@parameter
for k in range(size):
acc += a_shared[local_row, k] * b_shared[k, local_col]
output[row, col] = acc
LayoutTensor๋ฅผ ํ์ฉํ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ๊ตฌํ์ ํจ์จ์ ์ธ ๋ฉ๋ชจ๋ฆฌ ์ ๊ทผ ํจํด์ ํตํด ์ฑ๋ฅ์ ํฅ์์ํต๋๋ค:
๋ฉ๋ชจ๋ฆฌ ๊ตฌ์ฑ
Input Tensors (2ร2): Shared Memory (3ร3):
Matrix A: a_shared:
[a[0,0] a[0,1]] [s[0,0] s[0,1] s[0,2]]
[a[1,0] a[1,1]] [s[1,0] s[1,1] s[1,2]]
[s[2,0] s[2,1] s[2,2]]
Matrix B: b_shared: (๋น์ทํ ๋ ์ด์์)
[b[0,0] b[0,1]] [t[0,0] t[0,1] t[0,2]]
[b[1,0] b[1,1]] [t[1,0] t[1,1] t[1,2]]
[t[2,0] t[2,1] t[2,2]]
๊ตฌํ ๋จ๊ณ
-
๊ณต์ ๋ฉ๋ชจ๋ฆฌ ์ค์ :
# address_space๋ฅผ ์ง์ ํ LayoutTensor๋ก 2D ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ํ ์ ์์ฑ a_shared = LayoutTensor[dtype, Layout.row_major(TPB, TPB), MutAnyOrigin, address_space = AddressSpace.SHARED].stack_allocation() b_shared = LayoutTensor[dtype, Layout.row_major(TPB, TPB), MutAnyOrigin, address_space = AddressSpace.SHARED].stack_allocation() -
์ค๋ ๋ ์ธ๋ฑ์ฑ:
# ํ๋ ฌ ์ ๊ทผ์ ์ํ ์ ์ญ ์ธ๋ฑ์ค row = block_dim.y * block_idx.y + thread_idx.y col = block_dim.x * block_idx.x + thread_idx.x # ๊ณต์ ๋ฉ๋ชจ๋ฆฌ์ฉ ๋ก์ปฌ ์ธ๋ฑ์ค local_row = thread_idx.y local_col = thread_idx.x -
๋ฐ์ดํฐ ๋ก๋ฉ:
# LayoutTensor ์ธ๋ฑ์ฑ์ผ๋ก ๋ฐ์ดํฐ๋ฅผ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ์ ๋ก๋ if row < size and col < size: a_shared[local_row, local_col] = a[row, col] b_shared[local_row, local_col] = b[row, col] -
๊ณต์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ฌ์ฉํ ์ฐ์ฐ:
# ๊ฐ๋๋ก ์ ํจํ ํ๋ ฌ ์์๋ง ๊ณ์ฐ if row < size and col < size: # ์ถ๋ ฅ ํ ์์ ํ์ ์ผ๋ก ๋์ ๋ณ์ ์ด๊ธฐํ var acc: output.element_type = 0 # ์ปดํ์ผ ํ์์ ์ ๊ฐ๋๋ ํ๋ ฌ ๊ณฑ์ ๋ฃจํ @parameter for k in range(size): acc += a_shared[local_row, k] * b_shared[k, local_col] # ํ๋ ฌ ๊ฒฝ๊ณ ๋ด์ ์ค๋ ๋๋ง ๊ฒฐ๊ณผ ๊ธฐ๋ก output[row, col] = acc์ฃผ์ ํฌ์ธํธ:
-
๊ฒฝ๊ณ ๊ฒ์ฌ:
if row < size and col < size- ๋ฒ์ ๋ฐ ์ฐ์ฐ ๋ฐฉ์ง
- ์ ํจํ ์ค๋ ๋๋ง ์์ ์ํ
- TPB (3ร3) > SIZE (2ร2)์ด๋ฏ๋ก ํ์
-
๋์ ๋ณ์ ํ์ :
var acc: output.element_type- ์ถ๋ ฅ ํ ์์ ์์ ํ์ ์ผ๋ก ํ์ ์์ ์ฑ ํ๋ณด
- ์ผ๊ด๋ ์์น ์ ๋ฐ๋ ๋ณด์ฅ
- ๋์ ์ ์ 0์ผ๋ก ์ด๊ธฐํ
-
๋ฃจํ ์ต์ ํ:
@parameter for k in range(size)- ์ปดํ์ผ ํ์์ ๋ฃจํ ์ ๊ฐ
- ๋ ๋์ ๋ช ๋ น์ด ์ค์ผ์ค๋ง ๊ฐ๋ฅ
- ํฌ๊ธฐ๊ฐ ์๊ณ ๋ฏธ๋ฆฌ ์๋ ค์ง ํ๋ ฌ์ ํจ๊ณผ์
-
๊ฒฐ๊ณผ ๊ธฐ๋ก:
output[row, col] = acc- ๋์ผํ ๊ฐ๋ ์กฐ๊ฑด์ผ๋ก ๋ณดํธ
- ์ ํจํ ์ค๋ ๋๋ง ๊ฒฐ๊ณผ ๊ธฐ๋ก
- ํ๋ ฌ ๊ฒฝ๊ณ ์์ ์ฑ ์ ์ง
-
์ค๋ ๋ ์์ ์ฑ๊ณผ ๋๊ธฐํ
-
๊ฐ๋ ์กฐ๊ฑด:
- ์
๋ ฅ ๋ก๋ฉ:
if row < size and col < size - ์ฐ์ฐ: ๋์ผํ ๊ฐ๋๋ก ์ค๋ ๋ ์์ ์ฑ ๋ณด์ฅ
- ์ถ๋ ฅ ๊ธฐ๋ก: ๊ฐ์ ์กฐ๊ฑด์ผ๋ก ๋ณดํธ
- ์๋ชป๋ ๋ฉ๋ชจ๋ฆฌ ์ ๊ทผ๊ณผ ๊ฒฝ์ ์ํ ๋ฐฉ์ง
- ์
๋ ฅ ๋ก๋ฉ:
-
๋ฉ๋ชจ๋ฆฌ ์ ๊ทผ ์์ ์ฑ:
- ๊ณต์ ๋ฉ๋ชจ๋ฆฌ: TPB ๋ฒ์ ๋ด์์๋ง ์ ๊ทผ
- ์ ์ญ ๋ฉ๋ชจ๋ฆฌ: ํฌ๊ธฐ ๊ฒ์ฌ๋ก ๋ณดํธ
- ์ถ๋ ฅ: ๊ฐ๋๋ ์ฐ๊ธฐ๋ก ๋ฐ์ดํฐ ์์ ๋ฐฉ์ง
์ฃผ์ ์ธ์ด ๊ธฐ๋ฅ
-
LayoutTensor์ ์ฅ์ :
- ์ง์ 2D ์ธ๋ฑ์ฑ์ผ๋ก ์ฝ๋ ๋จ์ํ
element_type์ ํตํ ํ์ ์์ ์ฑ- ํจ์จ์ ์ธ ๋ฉ๋ชจ๋ฆฌ ๋ ์ด์์ ์ฒ๋ฆฌ
-
๊ณต์ ๋ฉ๋ชจ๋ฆฌ ํ ๋น:
- address_space๋ฅผ ์ง์ ํ LayoutTensor๋ก ๊ตฌ์กฐํ๋ ํ ๋น
- ์ ๋ ฅ ํ ์์ ๋์ผํ ํ ์ฐ์ ๋ ์ด์์
- ํจ์จ์ ์ ๊ทผ์ ์ํ ์ ์ ํ ๋ฉ๋ชจ๋ฆฌ ์ ๋ ฌ
-
๋๊ธฐํ:
barrier()๋ก ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ์ผ๊ด์ฑ ๋ณด์ฅ- ๋ก๋์ ์ฐ์ฐ ๊ฐ ์ ์ ํ ๋๊ธฐํ
- ๋ธ๋ก ๋ด ์ค๋ ๋ ๊ฐ ํ๋ ฅ
์ฑ๋ฅ ์ต์ ํ
-
๋ฉ๋ชจ๋ฆฌ ์ ๊ทผ ํจ์จ:
- ์์๋น ์ ์ญ ๋ฉ๋ชจ๋ฆฌ ๋ก๋ 1ํ
- ๊ณต์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ํตํ ๋ค์ค ์ฌ์ฌ์ฉ
- ๋ณํฉ๋(coalesced) ๋ฉ๋ชจ๋ฆฌ ์ ๊ทผ ํจํด
-
์ค๋ ๋ ํ๋ ฅ:
- ํ๋ ฅ์ ๋ฐ์ดํฐ ๋ก๋ฉ
- ๊ณต์ ๋ฐ์ดํฐ ์ฌ์ฌ์ฉ
- ํจ์จ์ ์ธ ์ค๋ ๋ ๋๊ธฐํ
-
์ฐ์ฐ ์ด์ :
- ์ ์ญ ๋ฉ๋ชจ๋ฆฌ ํธ๋ํฝ ๊ฐ์
- ์บ์ ํ์ฉ๋ ํฅ์
- ๋ช ๋ น์ด ์ฒ๋ฆฌ๋ ๊ฐ์
์ด ๊ตฌํ์ ๋ค์์ ํตํด ๊ธฐ๋ณธ ๋ฒ์ ๋๋น ์ฑ๋ฅ์ ํฌ๊ฒ ํฅ์์ํต๋๋ค:
- ์ ์ญ ๋ฉ๋ชจ๋ฆฌ ์ ๊ทผ ํ์ ๊ฐ์
- ๊ณต์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ํตํ ๋ฐ์ดํฐ ์ฌ์ฌ์ฉ
- LayoutTensor์ ํจ์จ์ ์ธ 2D ์ธ๋ฑ์ฑ ํ์ฉ
- ์ ์ ํ ์ค๋ ๋ ๋๊ธฐํ ์ ์ง