Puzzle 19: Attention Op
Overview
In this puzzle you implement the attention mechanism as a custom MAX graph operation. Attention is a core building block of transformers and many other well-known neural networks: it lets a model focus on the relevant parts of its input when making predictions.
Mathematically, the attention function is defined as:
$$\Large \text{Attention}(Q, K, V) = \text{softmax}(Q \cdot K^T) \cdot V$$
Where:
- \(Q\) is the query vector of shape \((d,)~\) - it represents what we are looking for
- \(K\) is the key matrix of shape \((\text{seq_len}, d)~\) - it represents what is available to match against
- \(V\) is the value matrix of shape \((\text{seq_len}, d)~\) - it represents the information to retrieve
- The output is a weighted-sum vector of shape \((d,)\)
The computation consists of three main steps (a NumPy sketch follows the list):
- Attention scores: compute \(Q \cdot K^T\) to measure how well the query matches each key vector
- Attention weights: apply softmax to turn the scores into a probability distribution (the weights sum to 1)
- Weighted sum: combine the value vectors using the attention weights to produce the final output
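To make these three steps concrete, here is a small NumPy reference sketch of the vector-attention case used in this puzzle (a single query of shape `(d,)`). It is illustrative reference math, not the Mojo kernel code:

```python
import numpy as np

def attention_reference(q, k, v):
    """Vector attention: q is (d,), k and v are (seq_len, d)."""
    scores = k @ q                           # step 1: Q . K[i] for every key, shape (seq_len,)
    weights = np.exp(scores - scores.max())  # subtract the max for numerical safety
    weights /= weights.sum()                 # step 2: softmax, weights sum to 1.0
    return weights @ v                       # step 3: weighted sum of values, shape (d,)

q = np.random.randn(16).astype(np.float32)
k = np.random.randn(16, 16).astype(np.float32)
v = np.random.randn(16, 16).astype(np.float32)
print(attention_reference(q, k, v).shape)    # (16,)
```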
Understanding attention: a step-by-step breakdown
Think of attention as a smart lookup mechanism. Given a query (what you are looking for), attention finds the most relevant information from a collection of key-value pairs:
- Step 1 - Similarity matching: compare the query \(Q\) against every key in \(K\) to get similarity scores
  - Compute \(Q \cdot K^T\), which measures how well \(Q\) matches each key vector
  - Higher score = better match
- Step 2 - Probability distribution: convert the raw scores into normalized weights
  - Apply softmax so that all weights sum to 1.0
  - This creates a probability distribution over which values to focus on
- Step 3 - Weighted retrieval: combine the values using the attention weights
  - Multiply each value vector by its corresponding weight
  - Sum everything to produce the final output
Real-world analogy: think of searching in a library. Your query is what you want to find, the book titles are the keys, and the book contents are the values. Attention computes how relevant each book is to your query, then gives you a weighted summary based on that relevance.
Visualizing the computation flow
Input: Q(16,) K(16,16) V(16,16)
          ↓        ↓         ↓
Step 1: Q(1,16) @ K^T(16,16) → Scores(1,16)
          ↓
Step 2: softmax(Scores) → Weights(1,16) [sum = 1.0]
          ↓
Step 3: Weights(1,16) @ V(16,16) → Output(1,16) → reshape → Output(16,)
Key insight: by reshaping the query vector \(Q\) from \((16,)\) to \((1,16)\), we can use matrix multiplication instead of dot products. This lets us reuse the highly optimized tiled matmul kernel from Puzzle 16 as-is! A quick NumPy check of this equivalence follows.
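The sketch below uses this puzzle's shapes (\(d = \text{seq_len} = 16\)) and is only a NumPy analogy for the LayoutTensor reshape done in the Mojo code:

```python
import numpy as np

d, seq_len = 16, 16
q = np.random.randn(d).astype(np.float32)
k = np.random.randn(seq_len, d).astype(np.float32)

# seq_len separate dot products Q . K[i] ...
scores_dot = np.array([q @ k[i] for i in range(seq_len)])
# ... equal a single (1, d) @ (d, seq_len) matrix multiplication
scores_mm = (q.reshape(1, d) @ k.T).reshape(seq_len)

assert np.allclose(scores_dot, scores_mm)
```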
The GPU implementation reuses and combines the optimized kernels from previous puzzles:
- Tiled matrix multiplication from Puzzle 16 → used for the efficient \(Q \cdot K^T\) and \(\text{weights} \cdot V\) operations
- Shared-memory transpose → computes \(K^T\) efficiently
- Parallel softmax from Puzzle 18 → used for numerically stable attention weight computation
Kernel reuse strategy: this puzzle shows how to build a complex operation by combining proven, optimized kernels from previous puzzles. Instead of writing everything from scratch, we leverage `matmul_idiomatic_tiled` from Puzzle 16 and `softmax_kernel` from Puzzle 18, demonstrating the power of modular GPU kernel design.
Key concepts
- Vector attention mechanism for sequence processing
- Kernel reuse: leveraging the proven implementations from Puzzle 16 and Puzzle 18
- Efficient matrix multiplication with shared-memory tiling
- Memory-optimized tensor reshaping that minimizes buffer allocations
- Integration of multiple optimized kernels into a single operation
- Custom MAX graph operation that takes multiple inputs
- CPU fallback implementation for compatibility
Configuration
- Sequence length: \(\text{SEQ_LEN} = 16~\) - number of key/value vectors in the sequence
- Model dimension: \(\text{D} = 16~\) - dimension of each vector (query, keys, values)
- Threads per block: tuned individually for each kernel
- Grid dimensions: computed dynamically so that matrices of various sizes are covered efficiently (see the sketch after this list)
- Shared memory: used in the transpose, matmul, and softmax kernels for performance
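The dynamic grid sizing is plain ceiling division; the snippet below works it out in Python for this puzzle's sizes, with a block side of 8 chosen purely for illustration (the real constants such as TRANSPOSE_BLOCK_DIM_XY are defined in the Mojo file):

```python
def blocks_needed(n: int, block: int) -> int:
    # Ceiling division: enough blocks to cover n elements, even when
    # n is not an exact multiple of the block size.
    return (n + block - 1) // block

seq_len, d, block = 16, 16, 8          # block side is an assumed example value
transpose_grid = (blocks_needed(d, block), blocks_needed(seq_len, block))
print(transpose_grid)                  # (2, 2) blocks tile the (16, 16) K matrix
```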
Layout configuration:
- Query tensor: `Layout.row_major(d)`
- Key tensor: `Layout.row_major(seq_len, d)`
- Value tensor: `Layout.row_major(seq_len, d)`
- Output tensor: `Layout.row_major(d)`
- Custom op parameters: `{"seq_len": seq_len, "d": d, "dtype": dtype}` (a sketch of the Python-side wiring follows this list)
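For orientation, here is a rough sketch of how such a custom op could be wired into a MAX graph from Python. Treat it as an assumption-laden illustration: the op name `"attention"`, the exact `ops.custom` keyword arguments, and the device handling are guesses and may differ from the actual `problems/p19` code:

```python
# Hypothetical sketch only - not the puzzle's actual Python graph definition.
from max.dtype import DType
from max.graph import Graph, TensorType, ops

def build_attention_graph(seq_len: int, d: int, device) -> Graph:
    dtype = DType.float32
    with Graph(
        "attention_graph",
        input_types=[
            TensorType(dtype, shape=(d,), device=device),          # Q
            TensorType(dtype, shape=(seq_len, d), device=device),  # K
            TensorType(dtype, shape=(seq_len, d), device=device),  # V
        ],
    ) as graph:
        q, k, v = graph.inputs
        out = ops.custom(
            name="attention",                                      # assumed op name
            values=[q, k, v],
            out_types=[TensorType(dtype, shape=(d,), device=device)],
            parameters={"seq_len": seq_len, "d": d, "dtype": dtype},
        )[0]
        graph.output(out)
    return graph
```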
The key aspects of this puzzle are:
- Multi-kernel orchestration: combining transpose, matmul, and softmax operations
- Memory optimization: minimizing memory allocations through reshape operations and buffer reuse
- Numerical stability: leveraging the proven softmax implementation from Puzzle 18
- Performance optimization: using the tiled algorithms from Puzzle 16 for all matrix operations
- Multi-input operation: handling three input tensors (Q, K, V) in a single custom op
The attention custom operation does the following:
- Takes the query, key, and value tensors as inputs from Python
- Processes them efficiently on the GPU using the optimized kernels
- Returns the attention-weighted output vector
- Matches the results of the NumPy reference implementation
Code to complete
To complete this puzzle, reuse the tiled matmul kernel from Puzzle 16 and the softmax kernel from Puzzle 18. You only need to implement the transpose kernel with shared memory in the Mojo file.
1. Implement the transpose kernel
fn transpose_kernel[
layout_in: Layout, # Layout for input matrix (seq_len, d)
layout_out: Layout, # Layout for output matrix (d, seq_len)
rows: Int,
cols: Int,
dtype: DType = DType.float32,
](
output: LayoutTensor[dtype, layout_out, MutAnyOrigin],
inp: LayoutTensor[dtype, layout_in, ImmutAnyOrigin],
):
# FILL ME IN (roughly 18 lines)
...
View the full file: problems/p19/op/attention.mojo
Tips
Transpose kernel implementation guide:
- Shared memory setup: use `LayoutTensor[dtype, Layout.row_major(TRANSPOSE_BLOCK_DIM_XY, TRANSPOSE_BLOCK_DIM_XY), MutAnyOrigin, address_space = AddressSpace.SHARED].stack_allocation()` to create a square shared-memory tile of size TRANSPOSE_BLOCK_DIM_XY × TRANSPOSE_BLOCK_DIM_XY. This enables efficient data exchange between threads.
- Thread indexing: map threads to matrix elements: `local_row = thread_idx.y`, `local_col = thread_idx.x` (position within the block); `global_row = block_idx.y * TRANSPOSE_BLOCK_DIM_XY + local_row` (position within the full matrix), and similarly for `global_col`
- Two-phase operation:
  - Phase 1: load data from global memory into shared memory using normal indexing
  - Phase 2: store data from shared memory to global memory using swapped indexing
- Synchronization required: call `barrier()` between the load and the store so that every thread has finished loading before any thread starts storing
- The key to the transpose: the transpose happens through swapped indexing: use `shared_tile[local_col, local_row]` instead of `shared_tile[local_row, local_col]` (a small Python simulation of this indexing appears after these tips)
- Boundary handling: bounds-check global memory accesses to avoid out-of-range reads/writes for matrices that do not divide evenly by TRANSPOSE_BLOCK_DIM_XY × TRANSPOSE_BLOCK_DIM_XY
- Memory coalescing: this pattern keeps both reads and writes coalesced for optimal memory bandwidth
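To see why the swapped block coordinates and swapped tile indexing together produce a transpose, here is a small pure-Python/NumPy simulation of the same indexing scheme (the tile size of 4 is arbitrary, chosen only for illustration):

```python
import numpy as np

def tiled_transpose(inp: np.ndarray, tile: int = 4) -> np.ndarray:
    rows, cols = inp.shape
    out = np.zeros((cols, rows), dtype=inp.dtype)
    for by in range((rows + tile - 1) // tile):          # plays the role of block_idx.y
        for bx in range((cols + tile - 1) // tile):      # plays the role of block_idx.x
            shared = np.zeros((tile, tile), dtype=inp.dtype)
            # Phase 1: load with normal indexing (bounds-checked)
            for ly in range(tile):
                for lx in range(tile):
                    gr, gc = by * tile + ly, bx * tile + lx
                    if gr < rows and gc < cols:
                        shared[ly, lx] = inp[gr, gc]
            # barrier() would go here on the GPU
            # Phase 2: store with swapped block coordinates and swapped tile indexing
            for ly in range(tile):
                for lx in range(tile):
                    out_row, out_col = bx * tile + ly, by * tile + lx
                    if out_row < cols and out_col < rows:
                        out[out_row, out_col] = shared[lx, ly]
    return out

a = np.arange(6 * 10, dtype=np.float32).reshape(6, 10)
assert np.array_equal(tiled_transpose(a), a.T)
```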
2. Attention orchestration
var gpu_ctx = rebind[DeviceContext](ctx[])
# Define layouts for matrix multiplication
# Q reshaped to (1, d)
comptime layout_q_2d = Layout.row_major(1, d)
# K^T is (d, seq_len)
comptime layout_k_t = Layout.row_major(d, seq_len)
# Scores as (1, seq_len)
comptime layout_scores_2d = Layout.row_major(1, seq_len)
# Weights as (1, seq_len)
comptime layout_weights_2d = Layout.row_major(1, seq_len)
# Result as (1, d)
comptime layout_result_2d = Layout.row_major(1, d)
# Transpose implementation limited to square (TRANSPOSE_BLOCK_DIM_XY x TRANSPOSE_BLOCK_DIM_XY) thread blocks
comptime transpose_threads_per_block = (
TRANSPOSE_BLOCK_DIM_XY,
TRANSPOSE_BLOCK_DIM_XY,
)
# Tile over the K (seq_len, d) matrix
comptime transpose_blocks_per_grid = (
(d + TRANSPOSE_BLOCK_DIM_XY - 1) // TRANSPOSE_BLOCK_DIM_XY,
(seq_len + TRANSPOSE_BLOCK_DIM_XY - 1)
// TRANSPOSE_BLOCK_DIM_XY,
)
# Matmul implementation limited to square (MATMUL_BLOCK_DIM_XY x MATMUL_BLOCK_DIM_XY) thread blocks
comptime matmul_threads_per_block = (
MATMUL_BLOCK_DIM_XY,
MATMUL_BLOCK_DIM_XY,
)
# seq_len outputs ( Q @ K^T = (1, d) @ (d, seq_len) -> (1, seq_len) ) with one thread per output
comptime scores_blocks_per_grid = (
seq_len + MATMUL_BLOCK_DIM_XY - 1
) // MATMUL_BLOCK_DIM_XY
comptime softmax_threads = SOFTMAX_BLOCK_DIM_X
comptime softmax_blocks_per_grid = 1
# d outputs ( weights @ V = (1, seq_len) @ (seq_len, d) -> (1, d) ) with one thread per output
comptime result_blocks_per_grid = (
d + MATMUL_BLOCK_DIM_XY - 1
) // MATMUL_BLOCK_DIM_XY
# Allocate minimal temporary buffers - reuse same buffer for different shapes
k_t_buf = gpu_ctx.enqueue_create_buffer[dtype](
seq_len * d
) # K^T as (d, seq_len)
scores_weights_buf = gpu_ctx.enqueue_create_buffer[dtype](
seq_len
) # Reused for scores and weights
k_t = LayoutTensor[dtype, layout_k_t, MutAnyOrigin](k_t_buf)
# Step 1: Reshape Q from (d,) to (1, d) - no buffer needed
# FILL ME IN 1 line
# Step 2: Transpose K from (seq_len, d) to K^T (d, seq_len)
# FILL ME IN 1 function call
# Step 3: Compute attention scores using matmul: Q @ K^T = (1, d) @ (d, seq_len) -> (1, seq_len)
# This computes Q Β· K^T[i] = Q Β· K[i] for each column i of K^T (which is row i of K)
# Reuse scores_weights_buf as (1, seq_len) for scores
# FILL ME IN 2 lines
# Step 4: Reshape scores from (1, seq_len) to (seq_len,) for softmax
# FILL ME IN 1 line
# Step 5: Apply softmax to get attention weights
# FILL ME IN 1 function call
# Step 6: Reshape weights from (seq_len,) to (1, seq_len) for final matmul
# FILL ME IN 1 line
# Step 7: Compute final result using matmul: weights @ V = (1, seq_len) @ (seq_len, d) -> (1, d)
# Reuse out_tensor reshaped as (1, d) for result
# FILL ME IN 2 lines
View the full file: problems/p19/op/attention.mojo
Testing the kernel
pixi run p19
pixi run -e amd p19
pixi run -e apple p19
uv run poe p19
When it succeeds, you should see output similar to the following on both CPU and GPU:
Input shapes: Q=(16,), K=(16, 16), V=(16, 16)
Sample Q values: [ 0.04967142 -0.01382643 0.06476886 0.15230298 -0.02341534]
Sample K[0] values: [-0.10128311 0.03142473 -0.09080241 -0.14123037 0.14656489]
Sample V[0] values: [ 0.11631638 0.00102331 -0.09815087 0.04621035 0.01990597]
================================================================================
STEP-BY-STEP VECTOR ATTENTION COMPUTATION DEBUG
================================================================================
1. INPUT SHAPES:
Q shape: (16,) (query vector)
K shape: (16, 16) (key matrix)
V shape: (16, 16) (value matrix)
Q[:5]: [ 0.04967142 -0.01382643 0.06476886 0.15230298 -0.02341534]
2. ATTENTION SCORES (K[i] · Q):
Scores shape: (16,)
Scores[:5]: [-0.03479404 -0.01563787 0.04834607 0.06764711 0.04001468]
Min: -0.061636, Max: 0.067647
Manual verification:
Q · K[0] = K[0] · Q = -0.034794 (computed: -0.034794)
Q · K[1] = K[1] · Q = -0.015638 (computed: -0.015638)
Q · K[2] = K[2] · Q = 0.048346 (computed: 0.048346)
3. SOFTMAX:
Max score: 0.067647
Attention weights shape: (16,)
Attention weights[:5]: [0.05981331 0.06097015 0.06499878 0.0662655 0.06445949]
Sum: 1.000000 (should be 1.0)
4. WEIGHTED SUM OF VALUES:
Output shape: (16,)
Output[:5]: [-0.00935538 -0.0243433 0.00306551 0.02346884 0.019306 ]
Output norm: 0.092764
Manual output[:5]: [-0.00935538 -0.0243433 0.00306551 0.02346884 0.019306 ]
Match: True
================================================================================
TESTING INDIVIDUAL OPERATIONS
================================================================================
Test 1: Vector Dot Product
a · b = 3.000000
Test 2: Matrix-Vector Multiplication
M @ v = [ 3. 7. 11.]
Test 3: Softmax
Input: [1. 2. 3. 4.]
Softmax: [0.0320586 0.08714432 0.2368828 0.6439143 ]
Sum: 1.000000
================================================================================
TESTING FULL ATTENTION
================================================================================
Compiling attention graph on Device(type=cpu,id=0)
Executing attention on Device(type=cpu,id=0)
====================================================================================================
CPU attention output[:5]: [-0.00935538 -0.02434331 0.00306551 0.02346884 0.019306 ]
CPU matches NumPy: True
Compiling attention graph on Device(type=gpu,id=0)
Executing attention on Device(type=gpu,id=0)
====================================================================================================
GPU attention output[:5]: [-0.00935538 -0.0243433 0.00306551 0.02346884 0.019306 ]
Expected output[:5]: [-0.00935538 -0.0243433 0.00306551 0.02346884 0.019306 ]
GPU matches NumPy: True
================================================================================
FINAL VERIFICATION
================================================================================
✓ CPU implementation PASSED
✓ GPU implementation PASSED
Output vector norms:
CPU: 0.092764
GPU: 0.092764
Expected: 0.092764
This output shows that the custom MAX graph operation implements the attention algorithm correctly and produces results that match the NumPy reference implementation.
Solution
To solve this puzzle, implement the transpose kernel in Mojo and complete the Python graph definition for the attention custom operation. The puzzle builds on concepts from earlier puzzles, combining Puzzle 16's tiled matrix multiplication and Puzzle 18's softmax into a complete attention mechanism.
Reused kernels
The implementation directly leverages these proven kernels:
- `matmul_idiomatic_tiled` (Puzzle 16) - used for both the \(Q \times K^T\) and \(\text{weights} \times V\) operations
- `softmax_kernel` (Puzzle 18) - provides numerically stable attention weight computation
This is a good example of modular GPU architecture: instead of one monolithic implementation, complex neural network operations are built by orchestrating proven, optimized components.
The attention operation follows the standard mathematical definition:
$$\Large \text{Attention}(Q, K, V) = \text{softmax}(Q \cdot K^T) \cdot V$$
Breaking down the expression:
- \(Q \cdot K^T~\): query-key similarity scores, shape \((1, \text{seq_len})\)
- \(\text{softmax}(\cdot)~\): normalizes the scores into probabilities, shape \((1, \text{seq_len})\)
- \(\text{weights} \cdot V~\): weighted combination of the values, shape \((1, d)\)
The process involves several computation steps, each optimized with a GPU kernel from a previous puzzle.
1. Transpose kernel implementation
fn transpose_kernel[
layout_in: Layout, # Layout for input matrix (seq_len, d)
layout_out: Layout, # Layout for output matrix (d, seq_len)
rows: Int,
cols: Int,
dtype: DType = DType.float32,
](
output: LayoutTensor[dtype, layout_out, MutAnyOrigin],
inp: LayoutTensor[dtype, layout_in, ImmutAnyOrigin],
):
"""Transpose matrix using shared memory tiling for coalesced access."""
shared_tile = LayoutTensor[
dtype,
Layout.row_major(TRANSPOSE_BLOCK_DIM_XY, TRANSPOSE_BLOCK_DIM_XY),
MutAnyOrigin,
address_space = AddressSpace.SHARED,
].stack_allocation()
local_row = Int(thread_idx.y)
local_col = Int(thread_idx.x)
global_row = Int(block_idx.y) * TRANSPOSE_BLOCK_DIM_XY + local_row
global_col = Int(block_idx.x) * TRANSPOSE_BLOCK_DIM_XY + local_col
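# Load the tile element from global memory into shared memory (coalesced read)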
if global_row < rows and global_col < cols:
shared_tile[local_row, local_col] = inp[global_row, global_col]
barrier()
out_row = Int(block_idx.x) * TRANSPOSE_BLOCK_DIM_XY + local_row
out_col = Int(block_idx.y) * TRANSPOSE_BLOCK_DIM_XY + local_col
# Store data from shared memory to global memory (coalesced write)
# Note: we transpose the shared memory access pattern
if out_row < cols and out_col < rows:
output[out_row, out_col] = shared_tile[local_col, local_row]
The transpose kernel uses shared-memory tiling to achieve coalesced memory access patterns. The key implementation details are:
The core transpose pattern
# Load with normal indexing
shared_tile[local_row, local_col] = inp[global_row, global_col]
barrier()
# Store with swapped indexing to perform the transpose
output[out_row, out_col] = shared_tile[local_col, local_row]
The transpose happens through swapped indexing in the shared-memory access ([local_col, local_row] instead of [local_row, local_col]) together with swapped block coordinates for the output position. This keeps both reads and writes coalesced while performing the transpose.
2. GPU kernel orchestration
# Step 1: Reshape Q from (d,) to (1, d) - no buffer needed
q_2d = q_tensor.reshape[layout_q_2d]()
# Step 2: Transpose K from (seq_len, d) to K^T (d, seq_len)
comptime kernel = transpose_kernel[
layout_k, layout_k_t, seq_len, d, dtype
]
gpu_ctx.enqueue_function[kernel, kernel](
k_t,
k_tensor,
grid_dim=transpose_blocks_per_grid,
block_dim=transpose_threads_per_block,
)
# Step 3: Compute attention scores using matmul: Q @ K^T = (1, d) @ (d, seq_len) -> (1, seq_len)
# This computes Q Β· K^T[i] = Q Β· K[i] for each column i of K^T (which is row i of K)
# Reuse scores_weights_buf as (1, seq_len) for scores
scores_2d = LayoutTensor[dtype, layout_scores_2d, MutAnyOrigin](
scores_weights_buf
)
comptime kernel2 = matmul_idiomatic_tiled[
layout_q_2d,
layout_k_t,
layout_scores_2d,
1,
seq_len,
d,
dtype,
]
gpu_ctx.enqueue_function[kernel2, kernel2](
scores_2d,
q_2d,
k_t,
grid_dim=scores_blocks_per_grid,
block_dim=matmul_threads_per_block,
)
# Step 4: Reshape scores from (1, seq_len) to (seq_len,) for softmax
weights = scores_2d.reshape[layout_scores]()
# Step 5: Apply softmax to get attention weights
comptime kernel3 = softmax_gpu_kernel[layout_scores, seq_len, dtype]
gpu_ctx.enqueue_function[kernel3, kernel3](
weights,
weights,
grid_dim=softmax_blocks_per_grid,
block_dim=softmax_threads,
)
# Step 6: Reshape weights from (seq_len,) to (1, seq_len) for final matmul
weights_2d = weights.reshape[layout_weights_2d]()
# Step 7: Compute final result using matmul: weights @ V = (1, seq_len) @ (seq_len, d) -> (1, d)
# Reuse out_tensor reshaped as (1, d) for result
result_2d = output_tensor.reshape[layout_result_2d]()
comptime kernel4 = matmul_idiomatic_tiled[
layout_weights_2d,
layout_v,
layout_result_2d,
1,
d,
seq_len,
dtype,
]
gpu_ctx.enqueue_function[kernel4, kernel4](
result_2d,
weights_2d,
v_tensor,
grid_dim=result_blocks_per_grid,
block_dim=matmul_threads_per_block,
)
The GPU orchestration demonstrates careful kernel chaining and zero-copy memory optimization:
Advanced memory optimization strategy
# Zero-copy reshape - reinterpret the tensor shape without moving any data
q_2d = q_tensor.reshape[layout_q_2d]()
# Aggressive buffer reuse - same memory, different interpretation
weights = scores_2d.reshape[layout_scores]()
The implementation achieves maximum memory efficiency through:
- Zero-copy reshaping: tensor shapes are reinterpreted without moving data in memory
- Smart buffer reuse: the same `scores_weights_buf` serves double duty, holding the scores as \((1, \text{seq_len})\) and then the weights as \((\text{seq_len},)\)
- Minimal allocations: the entire attention operation uses only 2 temporary buffers
- Memory coalescing: optimal memory access patterns are maintained in every operation
Strategic kernel reuse pattern
- Steps 3 & 7: both use `matmul_idiomatic_tiled` from Puzzle 16
  - Step 3: \(Q \times K^T\) computes the attention scores, \((1,d) \times (d,\text{seq_len}) \rightarrow (1,\text{seq_len})\)
  - Step 7: \(\text{weights} \times V\) computes the final weighted output, \((1,\text{seq_len}) \times (\text{seq_len},d) \rightarrow (1,d)\)
  - Both calls include bounds checking so that various matrix sizes are handled safely
- Step 5: uses `softmax_kernel` from Puzzle 18 (a minimal NumPy sketch of the math follows this list)
  - Converts the raw scores into a normalized probability distribution
  - Ensures numerical stability via max subtraction and a parallel reduction
  - Guarantees \(\sum_{i} \text{weights}[i] = 1.0\)
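As a reminder of what the numerically stable softmax computes, here is a minimal NumPy sketch of the reference math (not the parallel Mojo kernel itself):

```python
import numpy as np

def stable_softmax(scores: np.ndarray) -> np.ndarray:
    shifted = scores - scores.max()   # max subtraction prevents overflow in exp
    exps = np.exp(shifted)
    return exps / exps.sum()          # weights sum to 1.0 (up to rounding)

w = stable_softmax(np.array([1000.0, 1001.0, 1002.0]))
print(w, w.sum())                     # finite values; the naive exp would overflow
```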
This is a good example of modular GPU architecture: instead of one monolithic implementation, a complex neural network operation is built by orchestrating proven, optimized kernels!
Key implementation insights
Memory optimization strategy
Memory allocations are minimized through aggressive buffer reuse:
# Only two temporary buffers are needed for the entire operation
k_t_buf = gpu_ctx.enqueue_create_buffer[dtype](seq_len * d)
scores_weights_buf = gpu_ctx.enqueue_create_buffer[dtype](seq_len)
Key optimization points:
- The same `scores_weights_buf` is reused for both the attention scores and the attention weights via reshape operations
- Zero-copy tensor reshaping eliminates unnecessary data movement (a small NumPy analogy follows)
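The zero-copy idea can be illustrated in NumPy, where a reshape returns a view over the same memory whenever possible; this is only an analogy for the LayoutTensor `reshape` used above, not the same API:

```python
import numpy as np

scores = np.zeros((1, 16), dtype=np.float32)   # stand-in for the (1, seq_len) scores
weights = scores.reshape(16)                   # reinterpret as (seq_len,) for softmax

print(np.shares_memory(scores, weights))       # True: no data was copied
weights[0] = 1.0
print(scores[0, 0])                            # 1.0: both names refer to the same buffer
```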
Kernel reuse architecture
This puzzle demonstrates modular kernel design by combining three specialized kernels:
- `matmul_idiomatic_tiled` (used twice) - for both the \(Q \times K^T\) and \(\text{weights} \times V\) operations
- `softmax_kernel` - numerically stable attention weight computation using a parallel reduction
- `transpose_kernel` - efficient \(K^T\) computation with coalesced memory access
Benefits of this architecture:
- Composability: complex operations are built from proven components
- Maintainability: each kernel has a single, clearly defined role
- Performance: the highly optimized implementations from previous puzzles are reused as-is
- Scalability: the modular design extends easily to larger attention mechanisms
This implementation shows how a sophisticated neural network operation can be built by orchestrating simpler, well-tested GPU kernels rather than as a single monolithic implementation.