Puzzle 5: 브로드캐스트

개요

1D TileTensor ab를 브로드캐스트로 더해 2D TileTensor output에 저장하는 커널을 구현해 보세요.

참고: 스레드 수가 행렬의 위치 수보다 많습니다.

Broadcast 시각화 Broadcast 시각화

핵심 개념

이 퍼즐에서 배울 내용:

  • 브로드캐스트 연산에 TileTensor 사용하기
  • 서로 다른 텐서 크기 다루기
  • TileTensor로 2D 인덱싱 처리하기

핵심은 TileTensor가 서로 다른 텐서 크기 \((1, n)\)와 \((n, 1)\)을 \((n,n)\)으로 자연스럽게 브로드캐스트할 수 있다는 점입니다. 그러면서도 경계 검사는 여전히 필요합니다.

  • 텐서 크기: 입력 벡터의 크기는 \((1, n)\)과 \((n, 1)\)
  • 브로드캐스트: 두 차원을 결합해 \((n,n)\) 출력 생성
  • 가드 조건: 출력 크기에 대한 경계 검사는 여전히 필요
  • 스레드 범위: 텐서 원소 \((2 \times 2)\)보다 스레드 \((3 \times 3)\)가 많음

완성할 코드

comptime SIZE = 2
comptime BLOCKS_PER_GRID = 1
comptime THREADS_PER_BLOCK = (3, 3)
comptime dtype = DType.float32
comptime out_layout = row_major[SIZE, SIZE]()
comptime a_layout = row_major[1, SIZE]()
comptime b_layout = row_major[SIZE, 1]()
comptime OutLayout = type_of(out_layout)
comptime ALayout = type_of(a_layout)
comptime BLayout = type_of(b_layout)


def broadcast_add(
    output: TileTensor[mut=True, dtype, OutLayout, MutAnyOrigin],
    a: TileTensor[mut=False, dtype, ALayout, ImmutAnyOrigin],
    b: TileTensor[mut=False, dtype, BLayout, ImmutAnyOrigin],
    size: Int,
):
    var row = thread_idx.y
    var col = thread_idx.x
    # FILL ME IN (roughly 2 lines)


전체 코드 보기: problems/p05/p05.mojo

  1. 2D 인덱스 가져오기: row = thread_idx.y, col = thread_idx.x
  2. 가드 추가: if row < size and col < size
  3. 가드 내부: TileTensor로 ab 값을 어떻게 브로드캐스트할지 생각해 보세요

코드 실행

솔루션을 테스트하려면 터미널에서 다음 명령어를 실행하세요:

pixi run p05
pixi run -e amd p05
pixi run -e apple p05
uv run poe p05

퍼즐을 아직 풀지 않았다면 출력이 다음과 같이 나타납니다:

out: HostBuffer([0.0, 0.0, 0.0, 0.0])
expected: HostBuffer([1.0, 2.0, 11.0, 12.0])

솔루션

def broadcast_add(
    output: TileTensor[mut=True, dtype, OutLayout, MutAnyOrigin],
    a: TileTensor[mut=False, dtype, ALayout, ImmutAnyOrigin],
    b: TileTensor[mut=False, dtype, BLayout, ImmutAnyOrigin],
    size: Int,
):
    var row = thread_idx.y
    var col = thread_idx.x
    if row < size and col < size:
        output[row, col] = a[0, col] + b[row, 0]


TileTensor 브로드캐스트와 GPU 스레드 매핑의 핵심 개념을 보여주는 솔루션입니다:

  1. 스레드에서 행렬로 매핑

    • thread_idx.y로 행, thread_idx.x로 열에 접근
    • 자연스러운 2D 인덱싱이 출력 행렬 구조와 일치
    • 초과 스레드(3×3 그리드)는 경계 검사로 처리
  2. 브로드캐스트 작동 방식

    • 입력 a의 크기는 (1,n): a[0,col]이 행을 가로질러 브로드캐스트
    • 입력 b의 크기는 (n,1): b[row,0]이 열을 가로질러 브로드캐스트
    • 출력의 크기는 (n,n): 각 원소는 해당 브로드캐스트 값들의 합
    [ a0 a1 ]  +  [ b0 ]  =  [ a0+b0  a1+b0 ]
                  [ b1 ]     [ a0+b1  a1+b1 ]
    
  3. 경계 검사

    • 가드 조건 row < size and col < size로 범위 초과 접근 방지
    • 행렬 범위와 초과 스레드를 효율적으로 처리
    • 브로드캐스트 덕분에 ab에 대한 별도 검사 불필요

이 패턴은 이후 퍼즐에서 다룰 더 복잡한 텐서 연산의 기초가 됩니다.