μΈμ μν νλ‘κ·Έλλ°μ μ¬μ©ν κΉ
λΉ λ₯Έ νλ¨ κ°μ΄λ
β μν μ°μ°μ μ¬μ©ν λ:
- 32κ° μ΄μμ μμμ λν 리λμ
μ°μ° (
sum,max,min) - κ·μΉμ μΈ λ©λͺ¨λ¦¬ μ κ·Ό ν¨ν΄ (μΈμ λ μΈ β μΈμ μ£Όμ)
- ν¬λ‘μ€ μν€ν μ² μ΄μμ±μ΄ νμν κ²½μ° (NVIDIA/RDNA 32 vs CDNA 64 μ€λ λ)
- λ κ°λ¨νκ³ μ μ§λ³΄μνκΈ° μ¬μ΄ μ½λλ₯Ό μν λ
β κΈ°μ‘΄ λ°©μμ μ¬μ©ν λ:
- 볡μ‘ν μν κ° λκΈ°νκ° νμν κ²½μ°
- λΆκ·μΉνκ±°λ μ°λ°μ μΈ λ©λͺ¨λ¦¬ μ κ·Ό ν¨ν΄
- μ€λ λλ³ μμ λμ΄ λ€λ₯Έ κ²½μ° (μν λΆκΈ° λ°μ)
- λ¬Έμ ν¬κΈ°κ°
size < WARP_SIZEμΈ κ²½μ°
μ±λ₯ νΉμ±
λ¬Έμ ν¬κΈ°λ³ νμ₯μ±
| μμ μ | μν μ΄μ | λΉκ³ |
|---|---|---|
| < 32 | μμ | κΈ°μ‘΄ λ°©μμ΄ μ 리 |
| 32-1K | 1.2-1.5λ°° | μ΄μ μ΄ λνλκΈ° μμ |
| 1K-32K | 1.5-2.5λ°° | μν μ°μ°μ΄ νμ |
| > 32K | λ©λͺ¨λ¦¬ λ°μ΄λ | μμͺ½ λͺ¨λ λμνμ μν΄ μ ν |
μνμ ν΅μ¬ μ΄μ
- λκΈ°ν μ€λ²ν€λ μ λ‘: λ°°λ¦¬μ΄ λΉμ© μ κ±°
- μ΅μνμ λ©λͺ¨λ¦¬ μ¬μ©: 곡μ λ©λͺ¨λ¦¬ ν λΉ λΆνμ
- μ°μν νμ₯μ±: μν μκ° λμ΄λ μλ‘ μ±λ₯ ν₯μ
- κ°κ²°ν μ½λ: λ μ μ μ€ μ, λ μ μ μ€λ₯ κ°λ₯μ±
μκ³ λ¦¬μ¦λ³ κ°μ΄λ
| μκ³ λ¦¬μ¦ | κΆμ₯ μ¬ν | μ΄μ |
|---|---|---|
| λ΄μ | μν μ°μ° (1K+ μμ) | λ¨μΌ 리λμ , κ·μΉμ μ κ·Ό |
| νλ ¬ ν/μ΄ ν©κ³ | μν μ°μ° | μμ°μ€λ¬μ΄ 리λμ ν¨ν΄ |
| λμ ν© | νμ prefix_sum() μ¬μ© | νλμ¨μ΄ μ΅μ νλ κΈ°λ³Έ μμ |
| νλ§ (max/min) | μν μ°μ° (κ·μΉμ μλμ°) | ν¨μ¨μ μΈ μλμ° λ¦¬λμ |
| ꡬκ°μ΄ λ§μ νμ€ν κ·Έλ¨ | κΈ°μ‘΄ λ°©μ | λΆκ·μΉν μ°κΈ°, μμμ μ λ°μ΄νΈ |
μ½λ μμ
β μνμ μ ν©ν κ²½μ°
# 리λμ
μ°μ°
from gpu.primitives.warp import sum, max
var total = sum(partial_values)
var maximum = max(partial_values)
# ν΅μ ν¨ν΄
from gpu.primitives.warp import shuffle_idx, prefix_sum
var broadcast = shuffle_idx(my_value, 0)
var running_sum = prefix_sum(my_value)
β κΈ°μ‘΄ λ°©μμ΄ λμ κ²½μ°
# 볡μ‘ν λ€λ¨κ³ λκΈ°ν
stage1_compute()
barrier() # λͺ¨λ μ€λ λκ° μλ£λ λκΉμ§ λκΈ°
stage2_depends_on_stage1()
# λΆκ·μΉν λ©λͺ¨λ¦¬ μ κ·Ό
var value = input[random_indices[global_i]] # μ°λ°μ μ½κΈ°
# λ°μ΄ν° μμ‘΄μ μμ
if input[global_i] > threshold:
result = expensive_computation() # μν λΆκΈ° λ°μ
μ±λ₯ μΈ‘μ
# νμ μμͺ½ λ°©μμ λ²€μΉλ§ν¬νμΈμ
mojo p22.mojo --benchmark
# νμ₯ ν¨ν΄μ νμΈνμΈμ:
# traditional_1x: X.XX ms
# warp_1x: Y.YY ms # λ λΉ¨λΌμΌ ν¨
# warp_32x: Z.ZZ ms # μ΄μ μ΄ μ»€μ ΈμΌ ν¨
μμ½
μν μ°μ°μΌλ‘ μμνμΈμ:
- κ·μΉμ μΈ μ κ·Ό ν¨ν΄μ κ°μ§ 리λμ
- λ¬Έμ β₯ 1 μν ν¬κΈ°
- ν¬λ‘μ€ νλ«νΌ νΈνμ±μ΄ νμν κ²½μ°
κΈ°μ‘΄ λ°©μμ μ¬μ©νμΈμ:
- 볡μ‘ν λκΈ°νκ° νμν κ²½μ°
- λΆκ·μΉν λ©λͺ¨λ¦¬ ν¨ν΄
- μμ λ¬Έμ λλ μ¬ν λΆκΈ°
νλ¨μ΄ μ΄λ €μΈ λ: μμͺ½ λͺ¨λ ꡬννκ³ λ²€μΉλ§ν¬νμΈμ. μ±λ₯ μ°¨μ΄λ₯Ό 보면 λ΅μ΄ λμ΅λλ€.