Essentials of warp.sum() - warp-level dot product

This puzzle implements the dot product from Puzzle 12 using Mojo's warp operations, replacing complex shared-memory patterns with simple function calls. Each warp lane processes one element and warp.sum() combines the results automatically, demonstrating how warp programming transforms GPU synchronization.

Key insight: the warp.sum() operation exploits SIMT execution to replace shared memory + barriers + tree reduction with a single hardware-accelerated instruction.
Key concepts

In this puzzle, you will learn:
- Warp-level reduction with warp.sum()
- The SIMT execution model and lane synchronization
- Cross-architecture compatibility via WARP_SIZE
- The performance transformation from complex patterns to simple ones
- Lane ID management and conditional writes
The mathematical operation is a dot product:
\[\Large \text{output}[0] = \sum_{i=0}^{N-1} a[i] \times b[i]\]
But along the way, the implementation teaches the fundamental patterns that apply to all warp-level GPU programming in Mojo.
Configuration
- Vector size: SIZE = WARP_SIZE (32 or 64, depending on the GPU architecture)
- Data type: DType.float32
- Block configuration: (WARP_SIZE, 1) threads per block
- Grid configuration: (1, 1) blocks per grid
- Layout: Layout.row_major(SIZE) (1D row-major)
The complexity of the traditional approach (from Puzzle 12)
Recall the complex approach from solutions/p12/p12.mojo, which required shared memory, barriers, and a tree reduction:
```mojo
comptime SIZE = WARP_SIZE
comptime BLOCKS_PER_GRID = (1, 1)
comptime THREADS_PER_BLOCK = (WARP_SIZE, 1)
comptime dtype = DType.float32
comptime SIMD_WIDTH = simd_width_of[dtype]()
comptime in_layout = Layout.row_major(SIZE)
comptime out_layout = Layout.row_major(1)


fn traditional_dot_product_p12_style[
    in_layout: Layout, out_layout: Layout, size: Int
](
    output: LayoutTensor[dtype, out_layout, MutAnyOrigin],
    a: LayoutTensor[dtype, in_layout, ImmutAnyOrigin],
    b: LayoutTensor[dtype, in_layout, ImmutAnyOrigin],
):
    """
    This is the complex approach from p12_layout_tensor.mojo - kept for comparison.
    """
    shared = LayoutTensor[
        dtype,
        Layout.row_major(WARP_SIZE),
        MutAnyOrigin,
        address_space = AddressSpace.SHARED,
    ].stack_allocation()

    global_i = Int(block_dim.x * block_idx.x + thread_idx.x)
    local_i = Int(thread_idx.x)

    # Each thread stores its partial product in shared memory
    if global_i < size:
        shared[local_i] = (a[global_i] * b[global_i]).reduce_add()
    else:
        shared[local_i] = 0.0

    barrier()

    # Tree reduction with stride-based indexing
    stride = WARP_SIZE // 2
    while stride > 0:
        if local_i < stride:
            shared[local_i] += shared[local_i + stride]
        barrier()
        stride //= 2

    # Only thread 0 writes the final result
    if local_i == 0:
        output[global_i // WARP_SIZE] = shared[0]
```
Why this approach is complex:
- Shared memory allocation: memory is managed manually within the block
- Explicit barriers: barrier() calls are needed for thread synchronization
- Tree reduction: a complex loop with stride-based indexing
- Conditional writes: only thread 0 records the final result
It works, but the code is verbose, error-prone, and demands a deep understanding of GPU synchronization.
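To see what the stride-halving loop actually computes, here is a small CPU simulation of the tree reduction in plain Python. This is an illustration only: the names and data are hypothetical, and real shared memory and barrier() have no direct CPU equivalent (the comment marks where the barrier would sit on the GPU).

```python
WARP_SIZE = 32  # assume an NVIDIA-style warp for this illustration

def tree_reduce(values):
    """Simulate the shared-memory tree reduction from the kernel above.

    Each round halves the stride; on the GPU a barrier() separates
    rounds so every thread sees the previous round's writes.
    """
    # Pad to WARP_SIZE with zeros, mirroring the kernel's else-branch
    shared = list(values) + [0.0] * (WARP_SIZE - len(values))
    stride = WARP_SIZE // 2
    while stride > 0:
        # All "threads" with local_i < stride add in their partner's value
        for local_i in range(stride):
            shared[local_i] += shared[local_i + stride]
        stride //= 2  # barrier() would go here on the GPU
    return shared[0]

# Hypothetical partial products: lane i holds i*i
partials = [float(i * i) for i in range(WARP_SIZE)]
assert tree_reduce(partials) == sum(partials)
```

Five rounds (stride 16, 8, 4, 2, 1) combine all 32 values into shared[0], which is exactly why the loop needs a barrier per round and log2(WARP_SIZE) iterations.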
Test the traditional approach:
```bash
pixi run p24 --traditional
pixi run -e amd p24 --traditional
pixi run -e apple p24 --traditional
uv run poe p24 --traditional
```
Code to complete
1. Simple warp kernel approach
Transform the complex traditional approach into a simple warp kernel that uses warp_sum():
```mojo
fn simple_warp_dot_product[
    in_layout: Layout, out_layout: Layout, size: Int
](
    output: LayoutTensor[dtype, out_layout, MutAnyOrigin],
    a: LayoutTensor[dtype, in_layout, ImmutAnyOrigin],
    b: LayoutTensor[dtype, in_layout, ImmutAnyOrigin],
):
    global_i = Int(block_dim.x * block_idx.x + thread_idx.x)
    # FILL IN (6 lines at most)
```
View the full file: problems/p24/p24.mojo
Tips
1. Understand the simple warp kernel structure
You need to complete the simple_warp_dot_product function in at most 6 lines:
```mojo
fn simple_warp_dot_product[...](output, a, b):
    global_i = block_dim.x * block_idx.x + thread_idx.x
    # Fill in here (6 lines at most)
```
The pattern to follow:
- Compute this thread's partial product for its element
- Sum the values of all warp lanes with warp_sum()
- Lane 0 writes the final result
2. Compute the partial product
```mojo
var partial_product: Scalar[dtype] = 0
if global_i < size:
    partial_product = (a[global_i] * b[global_i]).reduce_add()
```
Why .reduce_add() is needed: values in Mojo are SIMD-based, so a[global_i] * b[global_i] returns a SIMD vector. .reduce_add() sums that vector into a scalar value.
Boundary check: the guard is required because not every thread is guaranteed to hold valid data.
3. The warp reduction magic
```mojo
total = warp_sum(partial_product)
```
What warp_sum() does:
- Takes each lane's partial_product value
- Sums the values of all lanes in the warp (hardware-accelerated)
- Returns the total to every lane (not only lane 0)
- Requires no explicit synchronization at all (SIMT handles it)
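The behavior described above can be modeled on the CPU: every lane contributes one value, and every lane receives the same total back. This is a Python sketch of the semantics only (the function name is made up for illustration), not of the hardware mechanism:

```python
def warp_sum_sim(lane_values):
    """Model warp_sum semantics: each lane passes in its value and
    every lane gets back the full sum (a broadcast, not just lane 0)."""
    total = sum(lane_values)
    return [total] * len(lane_values)  # one identical result per lane

partials = [0.0, 4.0, 16.0]  # a few example lane values
results = warp_sum_sim(partials)
assert all(r == 20.0 for r in results)  # every "lane" holds the same total
```

The broadcast property is why the kernel can pick any single lane (conventionally lane 0) to do the write.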
4. Write the result
```mojo
if lane_id() == 0:
    output[global_i // WARP_SIZE] = total
```
Why only lane 0? After warp_sum(), every lane holds the same total value, but we write it exactly once to avoid race conditions.
Why not write to output[0] directly? For flexibility: this function also works when there are multiple warps, with each warp's result written at position global_i // WARP_SIZE.
lane_id(): returns 0-31 (NVIDIA) or 0-63 (AMD), identifying which lane within the warp this thread is.
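The index arithmetic behind that write is easy to check by hand: with global thread index global_i, the warp index is global_i // WARP_SIZE and the lane within that warp is global_i % WARP_SIZE. A quick Python check (illustrative helper, assuming WARP_SIZE = 32):

```python
WARP_SIZE = 32  # NVIDIA-style; AMD GPUs typically use 64

def warp_and_lane(global_i):
    """Split a global thread index into (warp index, lane index)."""
    return global_i // WARP_SIZE, global_i % WARP_SIZE

assert warp_and_lane(0) == (0, 0)    # first thread: warp 0, lane 0
assert warp_and_lane(31) == (0, 31)  # last lane of warp 0
assert warp_and_lane(32) == (1, 0)   # first lane of warp 1
# In the kernel, lane 0 of warp w writes output[w], i.e. output[global_i // WARP_SIZE].
```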
Test the simple warp kernel:
```bash
uv run poe p24 --kernel
pixi run p24 --kernel
```
Expected output when solved:
```txt
SIZE: 32
WARP_SIZE: 32
SIMD_WIDTH: 8
=== RESULT ===
out: 10416.0
expected: 10416.0
Notice how simple the warp version is compared to p12.mojo!
Same kernel structure, but warp_sum() replaces all the complexity!
```
Solution
```mojo
fn simple_warp_dot_product[
    in_layout: Layout, out_layout: Layout, size: Int
](
    output: LayoutTensor[dtype, out_layout, MutAnyOrigin],
    a: LayoutTensor[dtype, in_layout, ImmutAnyOrigin],
    b: LayoutTensor[dtype, in_layout, ImmutAnyOrigin],
):
    global_i = Int(block_dim.x * block_idx.x + thread_idx.x)

    # Each thread computes one partial product using a vectorized approach,
    # as values in Mojo are SIMD based
    var partial_product: Scalar[dtype] = 0
    if global_i < size:
        partial_product = (a[global_i] * b[global_i]).reduce_add()

    # warp_sum() replaces all the shared memory + barriers + tree reduction
    total = warp_sum(partial_product)

    # Only lane 0 writes the result (all lanes have the same total)
    if lane_id() == 0:
        output[global_i // WARP_SIZE] = total
```
The simple warp kernel demonstrates the fundamental shift from complex synchronization to hardware-accelerated primitives:
What disappears from the traditional approach:
- 15+ lines → 6 lines: dramatic code reduction
- Shared memory allocation: no memory management needed
- 3+ barrier() calls: zero explicit synchronization
- Complex tree reduction: replaced by a single function call
- Stride-based indexing: eliminated entirely
The SIMT execution model:
```txt
Warp lanes (SIMT execution):
Lane 0:  partial_product = a[0] * b[0] = 0.0
Lane 1:  partial_product = a[1] * b[1] = 4.0
Lane 2:  partial_product = a[2] * b[2] = 16.0
...
Lane 31: partial_product = a[31] * b[31] = 3844.0

warp_sum() hardware operation:
All lanes → 0.0 + 4.0 + 16.0 + ... + 3844.0 = 10416.0
All lanes receive → total = 10416.0 (broadcast result)
```
Why it works without barriers:
- SIMT execution: all lanes execute each instruction simultaneously
- Hardware synchronization: by the time warp_sum() starts, every lane has already finished computing its partial_product
- Built-in communication: the GPU hardware handles the reduction operation
- Broadcast result: every lane ends up with the same total value
2. Functional approach
This time, implement the same warp dot product using Mojo's functional programming patterns:
```mojo
fn functional_warp_dot_product[
    layout: Layout,
    out_layout: Layout,
    dtype: DType,
    simd_width: Int,
    rank: Int,
    size: Int,
](
    output: LayoutTensor[mut=True, dtype, out_layout, MutAnyOrigin],
    a: LayoutTensor[mut=False, dtype, layout, MutAnyOrigin],
    b: LayoutTensor[mut=False, dtype, layout, MutAnyOrigin],
    ctx: DeviceContext,
) raises:
    @parameter
    @always_inline
    fn compute_dot_product[
        simd_width: Int, rank: Int, alignment: Int = align_of[dtype]()
    ](indices: IndexList[rank]) capturing -> None:
        idx = indices[0]
        print("idx:", idx)
        # FILL IN (10 lines at most)

    # Launch exactly size == WARP_SIZE threads (one warp) to process all elements
    elementwise[compute_dot_product, 1, target="gpu"](size, ctx)
```
Tips
1. Understand the functional structure
You need to complete the compute_dot_product function in at most 10 lines:
```mojo
@parameter
@always_inline
fn compute_dot_product[
    simd_width: Int, rank: Int
](indices: IndexList[rank]) capturing -> None:
    idx = indices[0]
    # Fill in here (10 lines at most)
```
How the functional pattern differs:
- Uses elementwise to launch exactly WARP_SIZE threads
- Each thread processes one element based on idx
- Same warp operations, different execution mechanism
2. Compute the partial product
```mojo
var partial_product: Scalar[dtype] = 0.0
if idx < size:
    a_val = a.load[1](idx, 0)
    b_val = b.load[1](idx, 0)
    partial_product = (a_val * b_val).reduce_add()
else:
    partial_product = 0.0
```
Loading pattern: a.load[1](idx, 0) loads exactly 1 element at position idx (no SIMD vectorization).
Boundary handling: out-of-range threads set partial_product to 0.0 so they contribute nothing to the sum.
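Setting out-of-range lanes to 0.0 works because 0 is the identity element for addition: the inactive lanes still participate in the warp-wide sum but do not change the result. A quick Python check of this masking pattern (the helper name and data are made up for illustration):

```python
def masked_partials(a, b, num_lanes):
    """Each lane computes a[i] * b[i] if i < size, else 0.0,
    mirroring the boundary check in compute_dot_product."""
    size = len(a)
    return [a[i] * b[i] if i < size else 0.0 for i in range(num_lanes)]

a = [1.0, 2.0, 3.0]
b = [4.0, 5.0, 6.0]
partials = masked_partials(a, b, num_lanes=8)  # 8 "lanes", only 3 valid elements
assert sum(partials) == 32.0                   # 1*4 + 2*5 + 3*6; the zeros add nothing
```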
3. Warp reduction and store
```mojo
total = warp_sum(partial_product)
if lane_id() == 0:
    output.store[1](Index(idx // WARP_SIZE), total)
```
Store pattern: output.store[1](Index(idx // WARP_SIZE), total) stores 1 element at position idx // WARP_SIZE in the output tensor.
Same warp logic: warp_sum() and the lane-0 write behave exactly as in the kernel approach.
4. Functions available from the imports
```mojo
from gpu import lane_id
from gpu.primitives.warp import sum as warp_sum, WARP_SIZE

# Inside the function:
my_lane = lane_id()         # 0 to WARP_SIZE-1
total = warp_sum(my_value)  # hardware-accelerated reduction
warp_size = WARP_SIZE       # 32 (NVIDIA) or 64 (AMD)
```
Test the functional approach:
```bash
uv run poe p24 --functional
pixi run p24 --functional
```
Expected output when solved:
```txt
SIZE: 32
WARP_SIZE: 32
SIMD_WIDTH: 8
=== RESULT ===
out: 10416.0
expected: 10416.0
Functional approach shows modern Mojo style with warp operations!
Clean, composable, and still leverages warp hardware primitives!
```
Solution
```mojo
fn functional_warp_dot_product[
    layout: Layout,
    out_layout: Layout,
    dtype: DType,
    simd_width: Int,
    rank: Int,
    size: Int,
](
    output: LayoutTensor[mut=True, dtype, out_layout, MutAnyOrigin],
    a: LayoutTensor[mut=False, dtype, layout, MutAnyOrigin],
    b: LayoutTensor[mut=False, dtype, layout, MutAnyOrigin],
    ctx: DeviceContext,
) raises:
    @parameter
    @always_inline
    fn compute_dot_product[
        simd_width: Int, rank: Int, alignment: Int = align_of[dtype]()
    ](indices: IndexList[rank]) capturing -> None:
        idx = indices[0]
        # Each thread computes one partial product
        var partial_product: Scalar[dtype] = 0.0
        if idx < size:
            a_val = a.load[1](Index(idx))
            b_val = b.load[1](Index(idx))
            partial_product = a_val * b_val
        else:
            partial_product = 0.0

        # Warp magic - combines all WARP_SIZE partial products!
        total = warp_sum(partial_product)

        # Only lane 0 writes the result (all lanes have the same total)
        if lane_id() == 0:
            output.store[1](Index(idx // WARP_SIZE), total)

    # Launch exactly size == WARP_SIZE threads (one warp) to process all elements
    elementwise[compute_dot_product, 1, target="gpu"](size, ctx)
```
The functional warp approach shows a modern Mojo programming pattern built on warp operations:
The launch that characterizes the functional approach:
```mojo
elementwise[compute_dot_product, 1, target="gpu"](size, ctx)
```
Advantages:
- Type safety: compile-time tensor layout checks
- Composability: easy integration with other functional operations
- Modern patterns: leverages Mojo's functional programming features
- Automatic optimization: the compiler can apply high-level optimizations
Key differences from the kernel approach:
- Launch mechanism: uses elementwise instead of enqueue_function
- Memory access: uses the .load[1]() and .store[1]() patterns
- Integration: combines naturally with other functional operations
Why the performance is identical:
- Zero synchronization: warp_sum() works exactly the same way
- Hardware acceleration: same performance as the kernel approach
- Cross-architecture: WARP_SIZE is defined automatically
Performance comparison via benchmarks
Run the comprehensive benchmark to see how warp operations scale:
```bash
uv run poe p24 --benchmark
pixi run p24 --benchmark
```
Here is an example of a full benchmark run:
```txt
SIZE: 32
WARP_SIZE: 32
SIMD_WIDTH: 8
--------------------------------------------------------------------------------
Testing SIZE=1 x WARP_SIZE, BLOCKS=1
Running traditional_1x
Running simple_warp_1x
Running functional_warp_1x
--------------------------------------------------------------------------------
Testing SIZE=4 x WARP_SIZE, BLOCKS=4
Running traditional_4x
Running simple_warp_4x
Running functional_warp_4x
--------------------------------------------------------------------------------
Testing SIZE=32 x WARP_SIZE, BLOCKS=32
Running traditional_32x
Running simple_warp_32x
Running functional_warp_32x
--------------------------------------------------------------------------------
Testing SIZE=256 x WARP_SIZE, BLOCKS=256
Running traditional_256x
Running simple_warp_256x
Running functional_warp_256x
--------------------------------------------------------------------------------
Testing SIZE=2048 x WARP_SIZE, BLOCKS=2048
Running traditional_2048x
Running simple_warp_2048x
Running functional_warp_2048x
--------------------------------------------------------------------------------
Testing SIZE=16384 x WARP_SIZE, BLOCKS=16384 (Large Scale)
Running traditional_16384x
Running simple_warp_16384x
Running functional_warp_16384x
--------------------------------------------------------------------------------
Testing SIZE=65536 x WARP_SIZE, BLOCKS=65536 (Massive Scale)
Running traditional_65536x
Running simple_warp_65536x
Running functional_warp_65536x
| name                   | met (ms)              | iters |
| ---------------------- | --------------------- | ----- |
| traditional_1x         | 0.00460128            | 100   |
| simple_warp_1x         | 0.00574047            | 100   |
| functional_warp_1x     | 0.00484192            | 100   |
| traditional_4x         | 0.00492671            | 100   |
| simple_warp_4x         | 0.00485247            | 100   |
| functional_warp_4x     | 0.00587679            | 100   |
| traditional_32x        | 0.0062406399999999996 | 100   |
| simple_warp_32x        | 0.0054918400000000004 | 100   |
| functional_warp_32x    | 0.00552447            | 100   |
| traditional_256x       | 0.0050614300000000004 | 100   |
| simple_warp_256x       | 0.00488768            | 100   |
| functional_warp_256x   | 0.00461472            | 100   |
| traditional_2048x      | 0.01120031            | 100   |
| simple_warp_2048x      | 0.00884383            | 100   |
| functional_warp_2048x  | 0.007038720000000001  | 100   |
| traditional_16384x     | 0.038533750000000005  | 100   |
| simple_warp_16384x     | 0.0323264             | 100   |
| functional_warp_16384x | 0.01674271            | 100   |
| traditional_65536x     | 0.19784991999999998   | 100   |
| simple_warp_65536x     | 0.12870176            | 100   |
| functional_warp_65536x | 0.048680310000000004  | 100   |
Benchmarks completed!

WARP OPERATIONS PERFORMANCE ANALYSIS:
GPU Architecture: NVIDIA (WARP_SIZE=32) vs AMD (WARP_SIZE=64)
- 1,...,256 x WARP_SIZE: Grid size too small to benchmark
- 2048 x WARP_SIZE: Warp primitive benefits emerge
- 16384 x WARP_SIZE: Large scale (512K-1M elements)
- 65536 x WARP_SIZE: Massive scale (2M-4M elements)
Expected Results at Large Scales:
• Traditional: Slower due to more barrier overhead
• Warp operations: Faster, scale better with problem size
• Memory bandwidth becomes the limiting factor
```
Performance insights from this example:
- Small scale (1x-4x): warp operations show modest gains (~10-15% faster)
- Medium scale (32x-256x): the functional approach often delivers the best performance
- Large scale (16K-65K): memory bandwidth becomes dominant and all approaches converge
- Variability: performance depends heavily on the specific GPU architecture and memory subsystem
Note: results vary significantly with hardware (GPU model, memory bandwidth, WARP_SIZE). The key is to observe relative performance trends rather than absolute numbers.
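To focus on the relative trends the note recommends, you can compute speedup ratios directly from the table. The numbers below are copied from the sample run above (your hardware will produce different ones):

```python
# Mean times in ms from the sample benchmark table above
traditional = {2048: 0.01120031, 16384: 0.03853375, 65536: 0.19784992}
functional = {2048: 0.00703872, 16384: 0.01674271, 65536: 0.04868031}

for scale in traditional:
    # Ratio > 1 means the functional warp version finished faster
    speedup = traditional[scale] / functional[scale]
    print(f"{scale}x: functional is {speedup:.2f}x faster than traditional")
```

The ratio grows with problem size (roughly 1.6x at 2048x up to about 4x at 65536x in this sample), which is the scaling trend the analysis describes.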
Next steps
Now that you've mastered warp sum operations, you can move on to:
- When to use warp programming: a strategic decision framework for warp vs. traditional approaches
- Advanced warp operations: shuffle_idx(), shuffle_down(), prefix_sum() for complex communication patterns
- Multi-warp algorithms: combining warp operations with block-level synchronization
- Memory coalescing optimization: tuning memory access patterns for maximum bandwidth
Key takeaway: warp operations transform GPU programming by replacing complex synchronization patterns with hardware-accelerated primitives. Once you understand the execution model, dramatic simplification is possible without sacrificing performance.