⚛️ 퓨전 vs 언퓨전 커널
개요
이 퍼즐에서는 LayerNorm과 Linear 연산에 대한 두 가지 접근 방식을 구현하고 비교하며, 커널 퓨전의 성능 이점을 탐구합니다:
- 언퓨전 방식: LayerNorm과 Linear를 별도의 연산으로 실행
- 퓨전 커널: LayerNorm과 Linear 연산을 하나의 GPU 커널로 결합
이 비교를 통해 커널 퓨전이 다음과 같은 방법으로 성능을 크게 개선할 수 있음을 보여줍니다:
- 메모리 대역폭 사용량 절감
- 커널 실행 오버헤드 최소화
- 캐시 활용도 향상
- 중간 결과 저장을 위한 메모리 할당 제거
핵심 개념
이 퍼즐에서 배울 내용:
- 여러 연산을 결합하는 커널 퓨전 기법
- 퓨전 연산을 통한 메모리 대역폭 최적화
- 서로 다른 커널 구현의 성능 벤치마킹
- 퓨전 연산에서의 수치 안정성
- PyTorch 커스텀 연산 통합
결합할 수학적 연산은 다음과 같습니다:
-
LayerNorm: \[\Large \text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta \]
-
Linear: \[\Large \text{Linear}(x) = Wx + b \]
퓨전 연산으로 결합하면 다음을 계산합니다: \[\Large \text{Fused}(x) = W(\gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta) + b \]
LayerNorm 이해하기
LayerNorm은 심층 신경망의 학습을 안정화하고 가속하는 정규화 기법입니다. 구성 요소와 파라미터를 하나씩 살펴보겠습니다:
LayerNorm이 하는 일
-
정규화: LayerNorm은 각 샘플의 특성(은닉 차원, hidden dimension) 전체에 걸쳐 활성화 값을 독립적으로 정규화합니다. 구체적으로:
- 각 시퀀스 위치에서 은닉 차원에 대한 통계량을 계산합니다
- 배치의 각 샘플은 독립적으로 정규화됩니다
- 배치 차원에 대해 정규화하는 BatchNorm과는 다릅니다
-
파라미터:
- \(\gamma\) (scale): 네트워크가 각 특성의 최적 스케일을 학습할 수 있게 하는 학습 가능한 파라미터 벡터
- \(\beta\) (shift): 네트워크가 각 특성의 최적 이동량을 학습할 수 있게 하는 학습 가능한 파라미터 벡터
- \(\epsilon\): 0으로 나누는 것을 방지하기 위해 분산에 더하는 작은 상수 (1e-5)
LayerNorm의 실제 역할
LayerNorm은 심층 신경망에서 여러 중요한 기능을 수행합니다:
-
특성 표준화:
- 각 특성을 평균 0, 분산 1로 변환합니다
- 네트워크의 학습 과정을 더 안정적으로 만듭니다
- 학습 중 레이어 입력의 분포가 변하는 “내부 공변량 이동(internal covariate shift)” 문제를 방지합니다
-
기울기 흐름:
- 네트워크를 통한 기울기 흐름을 개선합니다
- 기울기 소실/폭발 문제를 방지합니다
- 더 높은 학습률을 사용할 수 있어 학습 효율이 향상됩니다
-
정규화 효과:
- 암묵적인 정규화 역할을 합니다
- 특성 분포를 정규화하여 과적합을 방지합니다
- 입력 변동에 대한 네트워크의 강건성을 높입니다
-
시퀀스 모델링:
- 트랜스포머 아키텍처에서 특히 효과적입니다
- 서로 다른 시퀀스 길이에서도 일관된 신호 크기를 유지합니다
- 가변 길이 시퀀스를 더 잘 처리할 수 있게 합니다
-
학습 역학:
- 학습 수렴을 가속합니다
- 세밀한 학습률 조정의 필요성을 줄입니다
- 가중치 초기화에 대한 네트워크의 민감도를 낮춥니다
수학적 구성 요소
-
평균 계산 (\(\mu\)): \[\Large \mu = \frac{1}{H} \sum_{i=1}^{H} x_i \]
- 은닉 차원(H)에 걸쳐 평균을 계산합니다
- 각 시퀀스 위치마다 고유한 평균을 가집니다
-
분산 계산 (\(\sigma^2\)): \[\Large \sigma^2 = \frac{1}{H} \sum_{i=1}^{H} (x_i - \mu)^2 \]
- 은닉 차원에 걸쳐 분산을 계산합니다
- 정규화된 값의 스케일링에 사용됩니다
-
정규화와 스케일링: \[\Large \text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta \]
- 먼저 입력을 평균 0, 분산 1로 정규화합니다
- 그런 다음 학습 가능한 scale (\(\gamma\))과 shift (\(\beta\)) 파라미터를 적용합니다
- \(\odot\) 기호는 요소별 곱셈(아다마르 곱)을 나타냅니다
- 예를 들어, \(\gamma = [1.2, 0.8, 1.5]\)이고 정규화된 입력이 \([0.5, -0.3, 0.7]\)이면, \(\gamma \odot x = [0.6, -0.24, 1.05]\)입니다
LayerNorm이 중요한 이유
-
학습 안정성:
- 활성화 값이 너무 크거나 작아지는 것을 방지합니다
- 네트워크 전체에 걸쳐 일관된 신호 크기를 유지합니다
-
특성 학습:
- scale (\(\gamma\))과 shift (\(\beta\)) 파라미터를 통해 어떤 특성이 중요한지 학습할 수 있습니다
- 특정 특성을 무시하거나 강조하는 것을 효과적으로 학습할 수 있습니다
-
독립성:
- BatchNorm과 달리, LayerNorm의 통계량은 각 샘플에 대해 독립적으로 계산됩니다
- 가변 길이 시퀀스와 작은 배치 크기에 더 적합합니다
구성
- 배치 크기:
BATCH_SIZE = 4 - 시퀀스 길이:
SEQ_LEN = 4 - 은닉 차원:
HIDDEN_DIM = 8 - 출력 차원:
OUTPUT_DIM = 16 - 엡실론:
EPS = 1e-5 - 데이터 타입:
DType.float32
구현 방식
1. 언퓨전 구현
언퓨전 방식은 여러 커널을 사용하여 연산을 개별적으로 실행합니다. 이전 챕터에서 작성한 커널들을 살펴보겠습니다:
행렬 곱셈 커널
Puzzle 16: 행렬 곱셈 (MatMul)에서 사용한 타일링 행렬 곱셈 커널을 선형 변환에 재사용합니다. 이 커널은 다양한 행렬 크기를 안전하게 처리하기 위한 경계 검사를 포함합니다:
# Idiomatic tiled matmul from p19.mojo
fn matmul_idiomatic_tiled[
a_layout: Layout,
b_layout: Layout,
out_layout: Layout,
rows: Int,
cols: Int,
inner: Int,
dtype: DType = DType.float32,
](
output: LayoutTensor[dtype, out_layout, MutAnyOrigin],
a: LayoutTensor[dtype, a_layout, MutAnyOrigin],
b: LayoutTensor[dtype, b_layout, MutAnyOrigin],
):
"""Idiomatic tiled matrix multiplication from p19."""
local_row = thread_idx.y
local_col = thread_idx.x
tiled_row = Int(block_idx.y * MATMUL_BLOCK_DIM_XY + local_row)
tiled_col = Int(block_idx.x * MATMUL_BLOCK_DIM_XY + local_col)
# Get the tile of the output matrix that this thread block is responsible for
out_tile = output.tile[MATMUL_BLOCK_DIM_XY, MATMUL_BLOCK_DIM_XY](
Int(block_idx.y), Int(block_idx.x)
)
a_shared = LayoutTensor[
dtype,
Layout.row_major(MATMUL_BLOCK_DIM_XY, MATMUL_BLOCK_DIM_XY),
MutAnyOrigin,
address_space = AddressSpace.SHARED,
].stack_allocation()
b_shared = LayoutTensor[
dtype,
Layout.row_major(MATMUL_BLOCK_DIM_XY, MATMUL_BLOCK_DIM_XY),
MutAnyOrigin,
address_space = AddressSpace.SHARED,
].stack_allocation()
var acc: output.element_type = 0
comptime load_a_layout = Layout.row_major(
MATMUL_BLOCK_DIM_XY, MATMUL_BLOCK_DIM_XY
) # Coalesced loading
comptime load_b_layout = Layout.row_major(
MATMUL_BLOCK_DIM_XY, MATMUL_BLOCK_DIM_XY
) # Coalesced loading
@parameter
for idx in range((inner + MATMUL_BLOCK_DIM_XY - 1) // MATMUL_BLOCK_DIM_XY):
# Get tiles from A and B matrices
a_tile = a.tile[MATMUL_BLOCK_DIM_XY, MATMUL_BLOCK_DIM_XY](
Int(block_idx.y), idx
)
b_tile = b.tile[MATMUL_BLOCK_DIM_XY, MATMUL_BLOCK_DIM_XY](
idx, Int(block_idx.x)
)
# Asynchronously copy tiles to shared memory with consistent orientation
copy_dram_to_sram_async[
thread_layout=load_a_layout,
num_threads=MATMUL_NUM_THREADS,
block_dim_count=MATMUL_BLOCK_DIM_COUNT,
](a_shared, a_tile)
copy_dram_to_sram_async[
thread_layout=load_b_layout,
num_threads=MATMUL_NUM_THREADS,
block_dim_count=MATMUL_BLOCK_DIM_COUNT,
](b_shared, b_tile)
# Wait for all async copies to complete
async_copy_wait_all()
barrier()
# Compute partial matrix multiplication for this tile
@parameter
for k in range(MATMUL_BLOCK_DIM_XY):
if (
tiled_row < rows and tiled_col < cols
): # Only perform calculation for valid outputs
if k < a_tile.dim(
1
): # Only perform calculation on valid inputs
acc += a_shared[local_row, k] * b_shared[k, local_col]
barrier()
# Write final result with bounds checking (needed for variable matrix sizes)
if tiled_row < rows and tiled_col < cols:
out_tile[local_row, local_col] = acc
전치 커널
효율적인 메모리 접근 패턴을 위해 공유 메모리 타일링을 사용하는 전치 커널입니다:
fn transpose_kernel[
layout_in: Layout,
layout_out: Layout,
rows: UInt,
cols: UInt,
dtype: DType = DType.float32,
](
output: LayoutTensor[dtype, layout_out, MutAnyOrigin],
inp: LayoutTensor[dtype, layout_in, ImmutAnyOrigin],
):
"""Transpose matrix using shared memory tiling for coalesced access.
We will learn more about coalesced access in the next part.
"""
shared_tile = LayoutTensor[
dtype,
Layout.row_major(TRANSPOSE_BLOCK_DIM_XY, TRANSPOSE_BLOCK_DIM_XY),
MutAnyOrigin,
address_space = AddressSpace.SHARED,
].stack_allocation()
local_row = thread_idx.y
local_col = thread_idx.x
global_row = block_idx.y * TRANSPOSE_BLOCK_DIM_XY + local_row
global_col = block_idx.x * TRANSPOSE_BLOCK_DIM_XY + local_col
if global_row < rows and global_col < cols:
shared_tile[local_row, local_col] = inp[global_row, global_col]
barrier()
out_row = block_idx.x * TRANSPOSE_BLOCK_DIM_XY + local_row
out_col = block_idx.y * TRANSPOSE_BLOCK_DIM_XY + local_col
# Store data from shared memory to global memory (coalesced write)
# Note: we transpose the shared memory access pattern
if out_row < cols and out_col < rows:
output[out_row, out_col] = shared_tile[local_col, local_row]
Bias 합산 커널
Bias 항을 더하는 간단한 요소별 합산 커널입니다:
fn add_bias_kernel[
input_layout: Layout,
bias_layout: Layout,
output_layout: Layout,
batch_size: Int,
seq_len: Int,
output_dim: Int,
](
output: LayoutTensor[dtype, output_layout, MutAnyOrigin],
input: LayoutTensor[dtype, input_layout, MutAnyOrigin],
bias: LayoutTensor[dtype, bias_layout, ImmutAnyOrigin],
):
"""Simple bias addition."""
batch_idx = Int(block_idx.x)
seq_idx = Int(block_idx.y)
out_idx = Int(thread_idx.x)
if batch_idx >= batch_size or seq_idx >= seq_len or out_idx >= output_dim:
return
output[batch_idx, seq_idx, out_idx] = input[
batch_idx, seq_idx, out_idx
] + rebind[Scalar[dtype]](bias[out_idx])
LayerNorm 커널
이제 이 커널을 완성하여 LayerNorm 연산을 구현합니다. 다음이 필요합니다:
- 각 시퀀스 위치에 대한 평균 \(\mu\)과 분산 \(\sigma^2\) 계산
- 이 통계량을 사용하여 입력 정규화
- 스케일 \(\gamma\)과 시프트 \(\beta\) 파라미터 적용
fn layernorm_kernel[
input_layout: Layout,
ln_params_layout: Layout,
output_layout: Layout,
batch_size: Int,
seq_len: Int,
hidden_dim: Int,
](
output: LayoutTensor[dtype, output_layout, MutAnyOrigin],
input: LayoutTensor[dtype, input_layout, ImmutAnyOrigin],
ln_weight: LayoutTensor[dtype, ln_params_layout, ImmutAnyOrigin],
ln_bias: LayoutTensor[dtype, ln_params_layout, ImmutAnyOrigin],
):
batch_idx = Int(block_idx.x)
seq_idx = Int(block_idx.y)
hidden_idx = Int(thread_idx.x)
if (
batch_idx >= batch_size
or seq_idx >= seq_len
or hidden_idx >= hidden_dim
):
return
# Compute statistics for this sequence position (redundant but simple)
var sum_val: Scalar[dtype] = 0
var sq_sum: Scalar[dtype] = 0
# FILL ME IN (roughly 11 lines)
구현 단계:
- 먼저, 병렬 리덕션을 사용하여 평균과 분산을 계산합니다
- 그런 다음, 이 통계량으로 입력을 정규화합니다
- 마지막으로, 스케일과 시프트 파라미터를 적용합니다
언퓨전 방식의 특성:
- 여러 번의 커널 실행 (LayerNorm → MatMul → Bias)
- 연산 간 중간 텐서 할당
- 별도의 패스로 인한 메모리 대역폭 사용량 증가
- 관심사 분리가 명확한 간결한 구현
- 각 연산이 격리되어 디버깅이 용이
팁
-
스레드 구성:
- 시퀀스 위치당 하나의 스레드 블록 사용 (그리드:
[batch_size, seq_len]) - 각 스레드가 하나의 은닉 차원 요소를 처리
- 시퀀스당 통계량을 한 번만 계산하여 중복 연산 방지
- 시퀀스 위치당 하나의 스레드 블록 사용 (그리드:
-
메모리 접근:
- 입력 텐서:
[batch_idx, seq_idx, hidden_idx]로 접근 - 출력 텐서:
[batch_idx, seq_idx, hidden_idx]로 접근 - LayerNorm 파라미터:
[hidden_idx]로 접근
- 입력 텐서:
-
수치 안정성:
- 제곱근을 취하기 전에 엡실론(1e-5)을 더합니다
- 적절한 타입 캐스팅을 위해
rebind[Scalar[dtype]]사용 - 분산은 (sq_sum / hidden_dim) - (mean * mean)으로 계산
-
성능:
- 한 번의 패스로 평균과 분산을 동시에 계산
- 계산된 통계량을 시퀀스 내 모든 요소에 재사용
- 불필요한 메모리 배리어 방지
코드 실행
언퓨전 구현을 테스트하려면 다음을 실행하세요:
pixi run p22 --unfused
pixi run -e amd p22 --unfused
uv run poe p22 --unfused
출력은 다음과 같습니다:
Testing with dimensions: [4, 4, 8] -> [4, 4, 16]
✅ Loaded Mojo operations library
============================================================
Puzzle 22: UNFUSED Algorithm Test & Benchmark
============================================================
🧪 Correctness Testing for UNFUSED Algorithm
====================================================
Testing Reference PyTorch Implementation
-----------------------------------------------
✅ Reference PyTorch
Max difference: 0.00e+00
Result: ✅ CORRECT
Testing CPU Implementation
---------------------------------
✅ Using Mojo fused kernel (CPU)
Max difference: 1.86e-08
Result: ✅ CORRECT
Testing GPU Unfused Implementation
-----------------------------------------
✅ Using Mojo unfused kernel (GPU)
Max difference: 1.86e-08
Result: ✅ CORRECT
Correctness Summary:
- Reference: ✅ CORRECT
- CPU: ✅ CORRECT
- GPU unfused: ✅ CORRECT
Overall Correctness: ✅ ALL CORRECT
Benchmarking CPU vs GPU UNFUSED
------------------------------------------
Testing CPU performance...
CPU: 3173.70ms (50 iterations)
Testing GPU unfused performance...
GPU unfused: 3183.57ms (50 iterations)
GPU unfused vs CPU: 1.00x slower
CPU wins (GPU overhead > computation benefit)
UNFUSED Algorithm Test Completed!
솔루션
fn layernorm_kernel[
input_layout: Layout,
ln_params_layout: Layout,
output_layout: Layout,
batch_size: Int,
seq_len: Int,
hidden_dim: Int,
dtype: DType = DType.float32,
](
output: LayoutTensor[dtype, output_layout, MutAnyOrigin],
input: LayoutTensor[dtype, input_layout, ImmutAnyOrigin],
ln_weight: LayoutTensor[dtype, ln_params_layout, ImmutAnyOrigin],
ln_bias: LayoutTensor[dtype, ln_params_layout, ImmutAnyOrigin],
):
batch_idx = Int(block_idx.x)
seq_idx = Int(block_idx.y)
hidden_idx = Int(thread_idx.x)
if (
batch_idx >= batch_size
or seq_idx >= seq_len
or hidden_idx >= hidden_dim
):
return
# Compute statistics for this sequence position (redundant but simple)
var sum_val: Scalar[dtype] = 0
var sq_sum: Scalar[dtype] = 0
@parameter
for h in range(hidden_dim):
val = input[batch_idx, seq_idx, h]
sum_val += rebind[Scalar[dtype]](val)
sq_sum += rebind[Scalar[dtype]](val * val)
mean_val = sum_val / hidden_dim
var_val = (sq_sum / hidden_dim) - (mean_val * mean_val)
inv_std = 1.0 / sqrt(var_val + 1e-5)
# Apply LayerNorm to this element
input_val = input[batch_idx, seq_idx, hidden_idx]
normalized = (input_val - mean_val) * inv_std * rebind[Scalar[dtype]](
ln_weight[hidden_idx]
) + rebind[Scalar[dtype]](ln_bias[hidden_idx])
output[batch_idx, seq_idx, hidden_idx] = normalized
언퓨전 구현은 각 스레드가 출력 텐서의 하나의 요소를 처리하는 직관적인 방식을 따릅니다. 핵심 구성 요소를 하나씩 살펴보겠습니다:
-
스레드와 블록 구성:
batch_idx = block_idx.x seq_idx = block_idx.y hidden_idx = thread_idx.x-
각 스레드 블록이 배치 내 하나의 시퀀스 위치를 처리합니다
-
그리드 차원:
[batch_size, seq_len] -
각 스레드가 은닉 차원의 하나의 요소를 처리합니다
-
인덱스가 범위를 벗어나면 조기 반환합니다:
if (batch_idx >= batch_size or seq_idx >= seq_len or hidden_idx >= hidden_dim): return
-
-
통계량 계산:
var sum_val: Scalar[dtype] = 0 var sq_sum: Scalar[dtype] = 0 @parameter for h in range(hidden_dim): val = input[batch_idx, seq_idx, h] sum_val += rebind[Scalar[dtype]](val) sq_sum += rebind[Scalar[dtype]](val * val)-
한 번의 패스로 합계와 제곱합을 동시에 계산합니다
-
컴파일 타임 루프 전개를 위해
@parameter를 사용합니다 -
rebind[Scalar[dtype]]로 적절한 타입 캐스팅을 수행합니다 -
평균과 분산을 계산합니다:
mean_val = sum_val / hidden_dim var_val = (sq_sum / hidden_dim) - (mean_val * mean_val) inv_std = 1.0 / sqrt(var_val + 1e-5)
-
-
정규화와 스케일링:
input_val = input[batch_idx, seq_idx, hidden_idx] normalized = (input_val - mean_val) * inv_std * rebind[Scalar[dtype]]( ln_weight[hidden_idx] ) + rebind[Scalar[dtype]](ln_bias[hidden_idx]) output[batch_idx, seq_idx, hidden_idx] = normalized- 정규화를 적용합니다: \[\Large \text{normalized} = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta \]
- 학습 가능한 파라미터
γ(ln_weight)로 스케일링합니다 - 학습 가능한 bias
β(ln_bias)를 더합니다 - 결과를 출력 텐서에 저장합니다
-
성능 특성:
- 각 스레드가 독립적으로 통계량을 계산합니다
- 공유 메모리 사용 없음 (간단하지만 덜 효율적)
- 메모리 접근 패턴:
- 입력:
[batch_idx, seq_idx, h] - 출력:
[batch_idx, seq_idx, hidden_idx] - 파라미터:
[hidden_idx]
- 입력:
- 다음을 통해 수치 안정성을 보장합니다:
- 제곱근 전에 엡실론(1e-5) 추가
- 적절한 타입 캐스팅 사용
- 수치적으로 안정적인 방식으로 분산 계산
-
구현 세부 사항:
-
타입 안전성:
- 중간 계산에
Scalar[dtype]사용 - 적절한 타입 캐스팅을 위해
rebind[Scalar[dtype]]사용 - 일관된 부동소수점 정밀도 보장
- 중간 계산에
-
메모리 접근:
- 입력 텐서에서 병합 읽기
- 출력 텐서에 병합 쓰기
- LayerNorm 파라미터에 순차적 접근
-
연산 흐름:
- 통계량 계산: \[\Large O(H) \text{ operations per thread} \]
- 정규화: \[\Large O(1) \text{ operations per thread} \]
- 전체 복잡도: \[\Large O(H) \text{ per output element} \]
-
한계점:
- 통계량의 중복 계산
- 중간 결과를 위한 공유 메모리 없음
- 높은 메모리 대역폭 사용량
- 여러 번의 커널 실행 필요
-
이 구현은 정확하지만 성능 면에서 최적이 아니며, 벤치마크 결과에서 CPU 버전보다 약간 느린 것을 확인할 수 있습니다. 퓨전 구현에서는 다음을 통해 이러한 성능 한계를 해결합니다:
- 시퀀스당 통계량을 한 번만 계산
- 정규화된 값 재사용
- 메모리 트래픽 감소
- 중간 텐서 할당 제거
2. 퓨전 커널 구현
퓨전 커널은 LayerNorm과 Linear 연산을 하나의 GPU 커널로 결합합니다:
fn minimal_fused_kernel[
input_layout: Layout,
ln_params_layout: Layout,
weight_layout: Layout,
bias_layout: Layout,
output_layout: Layout,
batch_size: Int,
seq_len: Int,
hidden_dim: Int,
output_dim: Int,
](
output: LayoutTensor[dtype, output_layout, MutAnyOrigin],
input: LayoutTensor[dtype, input_layout, ImmutAnyOrigin],
ln_weight: LayoutTensor[dtype, ln_params_layout, ImmutAnyOrigin],
ln_bias: LayoutTensor[dtype, ln_params_layout, ImmutAnyOrigin],
linear_weight: LayoutTensor[dtype, weight_layout, ImmutAnyOrigin],
linear_bias: LayoutTensor[dtype, bias_layout, ImmutAnyOrigin],
):
"""Minimal fused kernel - one thread per sequence position to avoid redundancy.
"""
# Grid: (batch_size, seq_len) - one thread block per sequence position
# Block: (1,) - single thread per sequence position to avoid redundant computation
batch_idx = Int(block_idx.x)
seq_idx = Int(block_idx.y)
if batch_idx >= batch_size or seq_idx >= seq_len:
return
# Step 1: Compute LayerNorm statistics once per sequence position
# FILL IN roughly 10 lines
# Step 2: Compute all outputs for this sequence position
# FILL IN roughly 10 lines
핵심 최적화:
- 두 번 대신 한 번의 커널 실행
- 중간 결과를 위한 공유 메모리 활용
- 병합 메모리 접근 패턴
- 메모리 대역폭 사용량 절감
- 중간 텐서 할당 불필요
팁
-
스레드 구성:
- 시퀀스 위치당 하나의 스레드 블록 (그리드:
[batch_size, seq_len]) - 중복을 방지하기 위해 시퀀스 위치당 단일 스레드
- 각 시퀀스 위치의 모든 출력을 하나의 스레드에서 계산
- 시퀀스 위치당 하나의 스레드 블록 (그리드:
-
메모리 접근:
- 입력 텐서:
[batch_idx, seq_idx, h]로 접근 - 출력 텐서:
[batch_idx, seq_idx, out_idx]로 접근 - 가중치: 선형 레이어에서
[out_idx, h]로 접근
- 입력 텐서:
-
연산 흐름:
- 시퀀스당 LayerNorm 통계량을 한 번만 계산
- 모든 출력 차원에 정규화된 값을 재사용
- 정규화와 선형 변환을 결합
-
성능:
- 통계량의 중복 계산 방지
- 연산을 결합하여 메모리 트래픽 최소화
rebind[Scalar[dtype]]로 적절한 타입 캐스팅 사용
코드 실행
퓨전 구현을 테스트하려면 다음을 실행하세요:
pixi run p22 --fused
pixi run -e amd p22 --fused
uv run poe p22 --fused
출력은 다음과 같습니다:
Testing with dimensions: [4, 4, 8] -> [4, 4, 16]
✅ Loaded Mojo operations library
============================================================
Puzzle 22: FUSED Algorithm Test & Benchmark
============================================================
🧪 Correctness Testing for FUSED Algorithm
==================================================
Testing Reference PyTorch Implementation
-----------------------------------------------
✅ Reference PyTorch
Max difference: 0.00e+00
Result: ✅ CORRECT
Testing CPU Implementation
---------------------------------
✅ Using Mojo fused kernel (CPU)
Max difference: 1.86e-08
Result: ✅ CORRECT
Testing GPU Fused Implementation
---------------------------------------
✅ Using Mojo fused kernel (GPU)
Max difference: 1.86e-08
Result: ✅ CORRECT
Correctness Summary:
- Reference: ✅ CORRECT
- CPU: ✅ CORRECT
- GPU fused: ✅ CORRECT
Overall Correctness: ✅ ALL CORRECT
⚡ Benchmarking CPU vs GPU FUSED
----------------------------------------
Testing CPU performance...
CPU: 3144.75ms (50 iterations)
Testing GPU fused performance...
GPU fused: 3116.11ms (50 iterations)
GPU fused vs CPU: 1.01x faster
GPU fused wins!
FUSED Algorithm Test Completed!
솔루션
fn minimal_fused_kernel[
input_layout: Layout,
ln_params_layout: Layout,
weight_layout: Layout,
bias_layout: Layout,
output_layout: Layout,
batch_size: Int,
seq_len: Int,
hidden_dim: Int,
output_dim: Int,
dtype: DType = DType.float32,
](
output: LayoutTensor[dtype, output_layout, MutAnyOrigin],
input: LayoutTensor[dtype, input_layout, ImmutAnyOrigin],
ln_weight: LayoutTensor[dtype, ln_params_layout, ImmutAnyOrigin],
ln_bias: LayoutTensor[dtype, ln_params_layout, ImmutAnyOrigin],
linear_weight: LayoutTensor[dtype, weight_layout, ImmutAnyOrigin],
linear_bias: LayoutTensor[dtype, bias_layout, ImmutAnyOrigin],
):
"""Minimal fused kernel - one thread per sequence position to avoid redundancy.
"""
# Grid: (batch_size, seq_len) - one thread block per sequence position
# Block: (1,) - single thread per sequence position to avoid redundant computation
batch_idx = Int(block_idx.x)
seq_idx = Int(block_idx.y)
if batch_idx >= batch_size or seq_idx >= seq_len:
return
# Step 1: Compute LayerNorm statistics once per sequence position
var sum_val: Scalar[dtype] = 0
var sq_sum: Scalar[dtype] = 0
@parameter
for h in range(hidden_dim):
val = input[batch_idx, seq_idx, h]
sum_val += rebind[Scalar[dtype]](val)
sq_sum += rebind[Scalar[dtype]](val * val)
mean_val = sum_val / hidden_dim
var_val = (sq_sum / hidden_dim) - (mean_val * mean_val)
inv_std = 1.0 / sqrt(var_val + 1e-5)
# Step 2: Compute all outputs for this sequence position
@parameter
for out_idx in range(output_dim):
var acc: Scalar[dtype] = 0
@parameter
for h in range(hidden_dim):
input_val = input[batch_idx, seq_idx, h]
normalized = (input_val - mean_val) * inv_std * rebind[
Scalar[dtype]
](ln_weight[h]) + rebind[Scalar[dtype]](ln_bias[h])
acc += rebind[Scalar[dtype]](normalized * linear_weight[out_idx, h])
output[batch_idx, seq_idx, out_idx] = acc + rebind[Scalar[dtype]](
linear_bias[out_idx]
)
퓨전 구현은 연산들을 효율적으로 결합합니다:
-
스레드 구성:
- 시퀀스 위치당 하나의 스레드 블록 (그리드:
[batch_size, seq_len]) - 시퀀스 위치당 단일 스레드
- 스레드 인덱스:
batch_idx = block_idx.x,seq_idx = block_idx.y
- 시퀀스 위치당 하나의 스레드 블록 (그리드:
-
LayerNorm 단계:
- 시퀀스 위치에 대한 합계와 제곱합 계산
- 평균 계산: \[\Large \mu = \frac{1}{H} \sum_{i=1}^{H} x_i \]
- 분산 계산: \[\Large \sigma^2 = \frac{1}{H} \sum_{i=1}^{H} (x_i - \mu)^2 \]
- 역표준편차 계산: \[\Large \text{inv_std} = \frac{1}{\sqrt{\sigma^2 + \epsilon}} \]
-
Linear 단계:
- 각 출력 차원에 대해:
- 정규화된 값 계산: \[\Large \text{normalized} = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta \]
- 선형 가중치와 곱하고 누적: \[\Large \text{acc} = \sum_{h=1}^{H} \text{normalized}h \cdot W{out,h} \]
- 선형 bias 추가: \[\Large \text{output} = \text{acc} + b_{out} \]
- 결과를
output[batch_idx, seq_idx, out_idx]에 저장
- 각 출력 차원에 대해:
-
성능 최적화:
- 두 연산을 위한 단일 커널 실행
- 계산된 통계량 재사용
- 메모리 트래픽 최소화
- 중간 텐서 할당 불필요
- 효율적인 메모리 접근 패턴
이 구현은 메모리 대역폭 사용량과 커널 실행 오버헤드를 줄여 언퓨전 버전보다 더 나은 성능을 달성합니다.
커널 퓨전의 장점
이 퍼즐에서 LayerNorm + Linear 연산을 구현하는 두 가지 방식을 살펴보았습니다:
-
언퓨전 구현:
- LayerNorm과 Linear를 별도의 커널로 실행
- 구현이 간단하지만 덜 효율적
- 높은 메모리 대역폭 사용량
- 여러 번의 커널 실행
- 벤치마크 결과: 3183.57ms (GPU)
-
퓨전 구현:
- 두 연산을 결합한 단일 커널
- 더 복잡하지만 훨씬 효율적
- 메모리 대역폭 사용량 절감
- 단일 커널 실행
- 벤치마크 결과: 3116.11ms (GPU)
메모리 대역폭 최적화
-
메모리 트래픽 제거:
- 연산 간 중간 텐서 할당 불필요
- 전역 메모리 읽기/쓰기 감소
- 선형 변환을 위한 정규화된 값 재사용
- 메모리 대역폭 절감률: \[\Large \text{reduction} = \frac{\text{unfused_bandwidth} - \text{fused_bandwidth}}{\text{unfused_bandwidth}}\]
-
캐시 효율:
- L1/L2 캐시 활용도 향상
- 캐시 미스 감소
- 개선된 메모리 접근 패턴
- 더 높은 산술 강도
오버헤드 감소
-
커널 실행 최적화:
- 여러 번 대신 단일 커널 실행
- 드라이버 오버헤드 감소
- 동기화 지점 감소
- 메모리 할당 횟수 감소
-
리소스 관리:
- 연산 간 공유 메모리 재사용
- 레지스터 활용도 향상
- 스레드 점유율 개선
- GPU 활용률 향상
성능 특성
-
확장성:
- 입력 크기에 따른 성능 확장성 향상
- 메모리 대역폭 병목 감소
- GPU 리소스의 더 효율적인 활용
- 대규모 모델에서 처리량 향상
-
수치적 효율:
- 수치 안정성 유지
- 반올림 오차 감소
- 중간 결과의 정밀도 향상
- 최적화된 연산 순서
💡 핵심 통찰: 커널 퓨전은 트랜스포머 아키텍처의 LayerNorm + Linear처럼 신경망에서 자주 함께 사용되는 연산에 특히 유리합니다. 입력 크기가 크고 모델이 복잡할수록 성능 이점은 더욱 커집니다.