๊ฐ์
๋ฒกํฐ a์ ๊ฐ ์์น์ 10์ ๋ํด output์ ์ ์ฅํ๋ ์ปค๋์ ๊ตฌํํด ๋ณด์ธ์.
์ฐธ๊ณ : ๋ธ๋ก๋น ์ค๋ ๋ ์๊ฐ a์ ํฌ๊ธฐ๋ณด๋ค ์์ต๋๋ค.
ํต์ฌ ๊ฐ๋
์ด ํผ์ฆ์์ ๋ฐฐ์ธ ๋ด์ฉ:
- ์ค๋ ๋ ๋ธ๋ก ๋ด์์ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉํ๊ธฐ
- ๋ฐฐ๋ฆฌ์ด(barrier)๋ก ์ค๋ ๋ ๋๊ธฐํํ๊ธฐ
- ๋ธ๋ก ๋ก์ปฌ ๋ฐ์ดํฐ ์ ์ฅ์ ๊ด๋ฆฌํ๊ธฐ
ํต์ฌ์ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ๊ฐ ๋ธ๋ก ๋ด ๋ชจ๋ ์ค๋ ๋๊ฐ ์ ๊ทผํ ์ ์๋ ๋น ๋ฅธ ๋ก์ปฌ ์ ์ฅ์๋ผ๋ ์ , ๊ทธ๋ฆฌ๊ณ ์ด๋ฅผ ์ฌ์ฉํ ๋ ์ค๋ ๋ ๊ฐ ์กฐ์จ์ด ํ์ํ๋ค๋ ์ ์ ์ดํดํ๋ ๊ฒ์ ๋๋ค.
๊ตฌ์ฑ
- ๋ฐฐ์ด ํฌ๊ธฐ:
SIZE = 8์์ - ๋ธ๋ก๋น ์ค๋ ๋ ์:
TPB = 4 - ๋ธ๋ก ์: 2
- ๊ณต์ ๋ฉ๋ชจ๋ฆฌ: ๋ธ๋ก๋น
TPB๊ฐ ์์
์ฐธ๊ณ :
- ๊ณต์ ๋ฉ๋ชจ๋ฆฌ: ๋ธ๋ก ๋ด ์ค๋ ๋๋ค์ด ํจ๊ป ์ฌ์ฉํ๋ ๋น ๋ฅธ ์ ์ฅ์
- ์ค๋ ๋ ๋๊ธฐํ:
barrier()๋ฅผ ์ฌ์ฉํ ์กฐ์จ - ๋ฉ๋ชจ๋ฆฌ ๋ฒ์: ๊ณต์ ๋ฉ๋ชจ๋ฆฌ๋ ๋ธ๋ก ๋ด์์๋ง ๋ณด์
- ์ ๊ทผ ํจํด: ๋ก์ปฌ ์ธ๋ฑ์ค vs ์ ์ญ ์ธ๋ฑ์ค
์ฃผ์: ๊ฐ ๋ธ๋ก์ด ๊ฐ์ง ์ ์๋ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ํฌ๊ธฐ๋ ์์ ๋ก ์ ํด์ ธ์ผ ํฉ๋๋ค. ์ด ๊ฐ์ ๋ณ์๊ฐ ์๋ ๋ฆฌํฐ๋ด Python ์์์ฌ์ผ ํฉ๋๋ค. ๊ณต์ ๋ฉ๋ชจ๋ฆฌ์ ์ด ํ์๋ barrier๋ฅผ ํธ์ถํ์ฌ ์ค๋ ๋๋ค์ด ์๋ก ์์๊ฐ์ง ์๋๋ก ํด์ผ ํฉ๋๋ค.
ํ์ต ์ฐธ๊ณ : ์ด ํผ์ฆ์์๋ ๊ฐ ์ค๋ ๋๊ฐ ์์ ์ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ์์น์๋ง ์ ๊ทผํ๋ฏ๋ก barrier()๊ฐ ์๋ฐํ ํ์ํ์ง ์์ต๋๋ค. ํ์ง๋ง ๋ ๋ณต์กํ ์ํฉ์์ ํ์ํ ์ฌ๋ฐ๋ฅธ ๋๊ธฐํ ํจํด์ ์ตํ๊ธฐ ์ํด ํฌํจ๋์ด ์์ต๋๋ค.
์์ฑํ ์ฝ๋
comptime TPB = 4
comptime SIZE = 8
comptime BLOCKS_PER_GRID = (2, 1)
comptime THREADS_PER_BLOCK = (TPB, 1)
comptime dtype = DType.float32
fn add_10_shared(
output: UnsafePointer[Scalar[dtype], MutAnyOrigin],
a: UnsafePointer[Scalar[dtype], MutAnyOrigin],
size: UInt,
):
shared = stack_allocation[
TPB,
Scalar[dtype],
address_space = AddressSpace.SHARED,
]()
global_i = block_dim.x * block_idx.x + thread_idx.x
local_i = thread_idx.x
# Load local data into shared memory
if global_i < size:
shared[local_i] = a[global_i]
# wait for all threads to complete
# works within a thread block
barrier()
# FILL ME IN (roughly 2 lines)
์ ์ฒด ์ฝ๋ ๋ณด๊ธฐ: problems/p08/p08.mojo
ํ
barrier()๋ก ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ๋ก๋ ์๋ฃ ๋๊ธฐ (ํ์ต์ฉ - ์ฌ๊ธฐ์๋ ์๋ฐํ ํ์ํ์ง ์์)local_i๋ก ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ์ ๊ทผ:shared[local_i]global_i๋ก ์ถ๋ ฅ:output[global_i]- ๊ฐ๋ ์ถ๊ฐ:
if global_i < size
์ฝ๋ ์คํ
์๋ฃจ์ ์ ํ ์คํธํ๋ ค๋ฉด ํฐ๋ฏธ๋์์ ๋ค์ ๋ช ๋ น์ด๋ฅผ ์คํํ์ธ์:
pixi run p08
pixi run -e amd p08
pixi run -e apple p08
uv run poe p08
ํผ์ฆ์ ์์ง ํ์ง ์์๋ค๋ฉด ์ถ๋ ฅ์ด ๋ค์๊ณผ ๊ฐ์ด ๋ํ๋ฉ๋๋ค:
out: HostBuffer([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
expected: HostBuffer([11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0])
์๋ฃจ์
fn add_10_shared(
output: UnsafePointer[Scalar[dtype], MutAnyOrigin],
a: UnsafePointer[Scalar[dtype], MutAnyOrigin],
size: UInt,
):
shared = stack_allocation[
TPB,
Scalar[dtype],
address_space = AddressSpace.SHARED,
]()
global_i = block_dim.x * block_idx.x + thread_idx.x
local_i = thread_idx.x
# Load local data into shared memory
if global_i < size:
shared[local_i] = a[global_i]
# Wait for all threads to complete (works within a thread block).
# Note: barrier is not strictly needed here since each thread only accesses
# its own shared memory location. However, it's included to teach proper
# shared memory synchronization patterns for more complex scenarios where
# threads need to coordinate access to shared data.
barrier()
# process using shared memory
if global_i < size:
output[global_i] = shared[local_i] + 10
GPU ํ๋ก๊ทธ๋๋ฐ์์ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ์ ํต์ฌ ๊ฐ๋ ์ ๋ณด์ฌ์ฃผ๋ ์๋ฃจ์ ์ ๋๋ค:
-
๋ฉ๋ชจ๋ฆฌ ๊ณ์ธต ๊ตฌ์กฐ
-
์ ์ญ ๋ฉ๋ชจ๋ฆฌ:
a์output๋ฐฐ์ด (๋๋ฆผ, ๋ชจ๋ ๋ธ๋ก์์ ๋ณด์) -
๊ณต์ ๋ฉ๋ชจ๋ฆฌ:
shared๋ฐฐ์ด (๋น ๋ฆ, ์ค๋ ๋ ๋ธ๋ก ๋ก์ปฌ) -
๋ธ๋ก๋น 4๊ฐ ์ค๋ ๋๋ก 8๊ฐ ์์๋ฅผ ์ฒ๋ฆฌํ๋ ์์:
์ ์ญ ๋ฐฐ์ด a: [1 1 1 1 | 1 1 1 1] # ์ ๋ ฅ: ๋ชจ๋ 1 Block (0): Block (1): shared[0..3] shared[0..3] [1 1 1 1] [1 1 1 1]
-
-
์ค๋ ๋ ์กฐ์จ
-
๋ก๋ ๋จ๊ณ:
Thread 0: shared[0] = a[0]=1 Thread 2: shared[2] = a[2]=1 Thread 1: shared[1] = a[1]=1 Thread 3: shared[3] = a[3]=1 barrier() โ โ โ โ # ๋ชจ๋ ๋ก๋ ์๋ฃ ๋๊ธฐ -
์ฒ๋ฆฌ ๋จ๊ณ: ๊ฐ ์ค๋ ๋๊ฐ ์์ ์ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ๊ฐ์ 10์ ๋ํจ
-
๊ฒฐ๊ณผ:
output[i] = shared[local_i] + 10 = 11
์ฐธ๊ณ : ์ด ๊ฒฝ์ฐ์๋ ๊ฐ ์ค๋ ๋๊ฐ ์์ ์ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ์์น(
shared[local_i])์๋ง ์ฐ๊ณ ์ฝ์ผ๋ฏ๋กbarrier()๊ฐ ์๋ฐํ ํ์ํ์ง ์์ต๋๋ค. ํ์ง๋ง ์ค๋ ๋๋ค์ด ์๋ก์ ๋ฐ์ดํฐ์ ์ ๊ทผํ๋ ์ํฉ์์ ํ์์ ์ธ ๋๊ธฐํ ํจํด์ ์ตํ๊ธฐ ์ํด ํฌํจ๋์ด ์์ต๋๋ค. -
-
์ธ๋ฑ์ค ๋งคํ
-
์ ์ญ ์ธ๋ฑ์ค:
block_dim.x * block_idx.x + thread_idx.xBlock 0 ์ถ๋ ฅ: [11 11 11 11] Block 1 ์ถ๋ ฅ: [11 11 11 11] -
๋ก์ปฌ ์ธ๋ฑ์ค: ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ์ ๊ทผ์
thread_idx.x์ฌ์ฉ๋ ๋ธ๋ก ๋ชจ๋ ์ฒ๋ฆฌ: 1 + 10 = 11
-
-
๋ฉ๋ชจ๋ฆฌ ์ ๊ทผ ํจํด
- ๋ก๋: ์ ์ญ โ ๊ณต์ (๋ณํฉ ์ฝ๊ธฐ๋ก 1 ๊ฐ๋ค ๋ก๋)
- ๋๊ธฐํ:
barrier()๋ก ๋ชจ๋ ๋ก๋ ์๋ฃ ๋ณด์ฅ - ์ฒ๋ฆฌ: ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ๊ฐ์ 10 ๋ํ๊ธฐ
- ์ ์ฅ: ๊ฒฐ๊ณผ(11)๋ฅผ ์ ์ญ ๋ฉ๋ชจ๋ฆฌ์ ์ฐ๊ธฐ
์ด ํจํด์ ๋ธ๋ก ๋ด ์ค๋ ๋ ์กฐ์จ์ ์ ์งํ๋ฉด์ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ๋ก ๋ฐ์ดํฐ ์ ๊ทผ์ ์ต์ ํํ๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ์ค๋๋ค.