AI EngineeringEN

Flash Attention vs Sparse Attention — LLM 추론 속도를 가르는 핵심 기술

Flash Attention과 Sparse Attention의 원리부터 실전 벤치마크까지. DSA, DMS, Sliding Window 비교. 언제 뭘 쓸지 의사결정 매트릭스 포함.

Flash Attention vs Sparse Attention — LLM 추론 속도를 가르는 핵심 기술

Flash Attention vs Sparse Attention — LLM 추론 속도를 가르는 핵심 기술

AI 에이전트가 코드 저장소 전체를 분석하고, 수십만 토큰의 대화 기록을 처리하는 시대입니다. 컨텍스트가 길어질수록 Attention 연산의 $O(n^2)$ 비용이 치명적인 병목이 됩니다.

이 글에서는 이 병목을 해결하는 두 가지 핵심 기술 — Flash AttentionSparse Attention — 을 원리부터 실전 벤치마크까지 비교합니다.

문제: Attention의 $O(n^2)$ 벽

표준 Self-Attention은 모든 토큰 쌍의 관계를 계산합니다. 시퀀스 길이가 $n$이면:

  • 연산량: $O(n^2 \cdot d)$ — 토큰 수의 제곱에 비례
  • 메모리: $O(n^2)$ — Attention Score 행렬 전체를 메모리에 올려야 함
시퀀스 길이Attention 행렬 크기fp16 메모리
2K4M8 MB
8K64M128 MB
32K1B2 GB
128K16B32 GB

128K 컨텍스트에서 Attention Score 행렬만 32 GB입니다. 이건 head 하나에 대한 수치이고, multi-head까지 고려하면 현실적으로 불가능한 메모리입니다.

Flash Attention: 하드웨어 최적화

Flash Attention은 같은 수학적 결과를 내면서 메모리를 극적으로 줄이는 기법입니다. 핵심 아이디어는 GPU의 메모리 계층 구조를 활용하는 것입니다.

원리: Tiling + Online Softmax

표준 Attention은 $n \times n$ 크기의 Attention Score 행렬을 GPU의 HBM(High Bandwidth Memory)에 통째로 올립니다. Flash Attention은 이 행렬을 작은 타일(tile)로 나눠서 GPU의 SRAM(on-chip 고속 메모리)에서 처리합니다.

[표준 Attention]
1. Q × Kᵀ → Attention Score (n×n) 전체를 HBM에 저장
2. Softmax → HBM에서 읽고 다시 저장
3. Score × V → HBM에서 읽고 결과 저장
→ HBM 접근: 3회 (느림)

[Flash Attention]
1. Q, K, V를 타일 단위로 SRAM에 로드
2. 타일 내에서 Score 계산 + Softmax + V 곱셈을 한 번에 수행
3. 최종 결과만 HBM에 기록
→ HBM 접근: 1회 (빠름)

중요한 점: 연산량 자체는 줄어들지 않습니다. 여전히 $O(n^2)$입니다. 하지만 메모리 접근 횟수가 줄어들면서 실제 속도는 2~4배 빨라집니다.

Flash Attention 버전별 진화

버전연도핵심 개선속도 향상
Flash Attention 12022Tiling + Kernel Fusion표준 대비 2~4x
Flash Attention 22023병렬화 개선, 비대칭 Q/KV splitFA1 대비 ~2x
Flash Attention 32024Hopper GPU 최적화, FP8 지원FA2 대비 ~1.5x

코드에서 사용하기

python
from transformers import AutoModelForCausalLM

# Flash Attention 2는 transformers에 기본 통합
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    torch_dtype=torch.float16,
    device_map="auto",
    attn_implementation="flash_attention_2",  # 이 한 줄만 추가
)
python
# 또는 직접 사용
from flash_attn import flash_attn_func

# (batch, seq_len, num_heads, head_dim)
output = flash_attn_func(q, k, v, causal=True)

Flash Attention의 한계

  • 연산량은 그대로: 여전히 $O(n^2)$ — 128K+ 컨텍스트에서는 근본적인 해결이 안 됩니다
  • GPU 의존: CUDA가 필요하고, Hopper/Ampere 이상에서 최적 성능
  • 메모리 절약은 Attention Score에만: KV Cache 자체는 줄지 않습니다

Sliding Window Attention: 가장 단순한 Sparse

Flash Attention이 "같은 계산을 빠르게"라면, Sparse Attention은 "계산 자체를 줄이자"는 접근입니다.

가장 간단한 형태가 Sliding Window Attention입니다. 고정 크기 윈도우 내의 토큰만 참조합니다.

[Full Attention — 모든 토큰 참조]
Token 8: [1][2][3][4][5][6][7][8] ← 8개 참조

[Sliding Window (W=4)]
Token 8: [_][_][_][_][5][6][7][8] ← 4개만 참조
  • 장점: 연산이 $O(n \cdot w)$으로 줄어듭니다 ($w$=윈도우 크기)
  • 치명적 단점: 윈도우 밖의 초기 프롬프트를 완전히 잊어버립니다

Mistral 7B가 Sliding Window(4096)를 사용하는데, 초기 시스템 프롬프트가 4K 이전에 있으면 참조할 수 없습니다. 에이전트 워크플로우에서 이런 "기억상실"은 치명적입니다.

Sparse Attention: 똑똑하게 골라서 보기

Sparse Attention은 Sliding Window의 기억상실 문제를 해결합니다. 모든 토큰을 보는 것도 아니고, 최근 토큰만 보는 것도 아니라 — 가장 중요한 토큰만 동적으로 선별합니다.

DeepSeek Sparse Attention (DSA)

DeepSeek-V3.2부터 적용된 방법입니다. 2단계 파이프라인으로 동작합니다:

Stage 1 — Lightning Indexer (경량 스캐너)

  • 전체 컨텍스트를 빠르게 훑으며 각 토큰의 "중요도 점수"를 계산합니다
  • 연산량: 전체 Attention의 ~10% 수준
  • 결과: 상위 K개 토큰의 인덱스 리스트

Stage 2 — Selective Attention (정밀 계산)

  • Stage 1에서 선별된 토큰만 대상으로 Full Attention을 수행합니다
  • 전체 컨텍스트의 5~20%만 실제 계산
[Full Attention]
Current Token → 모든 128K 토큰과 계산 (128K ops)

[DSA]
Current Token → Indexer가 6K 토큰 선별 (13K ops)
             → 선별된 6K 토큰과 Full Attention (6K ops)
             → 총 ~19K ops (85% 절감)

IndexCache: DSA의 업그레이드

Z.ai 연구팀이 발표한 IndexCache는 DSA의 Indexer를 더 효율화합니다. 핵심 관찰: 인접한 레이어들이 거의 같은 토큰을 중요하다고 판단합니다.

Layer 15에서 "중요"하다고 선별된 토큰은 Layer 16에서도 대부분 중요합니다. 그래서 Indexer 결과를 인접 레이어끼리 공유합니다.

결과:

  • Indexer 연산량 75% 감소
  • 전체 추론 속도 1.82x 향상
  • 품질 손실은 거의 없음

Nvidia Dynamic Memory Sparsification (DMS)

Nvidia의 접근법은 다릅니다. 기존 모델을 수정하지 않고, 사후 학습(post-training)으로 토큰을 버리는 법을 학습시킵니다.

핵심 차이점 — Delayed Eviction (지연 제거):

[즉시 제거]
Token 중요도 < 임계값 → 바로 삭제
→ 문제: 나중에 필요해질 수 있는 토큰도 삭제

[DMS — 지연 제거]
Token 중요도 < 임계값 → "대기열"에 추가
→ 일정 시간 후에도 참조되지 않으면 삭제
→ 가비지 컬렉션처럼 동작

DMS 결과:

  • 일부 모델에서 추론 비용 8x 절감
  • 정확도 손실 없음
  • 기존 모델에 사후 적용 가능 (재학습 불필요)

벤치마크 비교

각 방법의 성능을 정리합니다. 128K 컨텍스트, 단일 요청 기준입니다.

방법연산 복잡도메모리 절감TTFT 개선품질 영향적용 난이도
표준 Attention$O(n^2)$기준기준없음-
Flash Attention 2$O(n^2)$Attention Score 없앰2~4x없음한 줄 설정
Sliding Window$O(n \cdot w)$캐시 크기 고정3~5x기억상실모델 내장
DeepSeek DSA$O(n \cdot k)$5~20%만 계산4~8x미미모델 내장
Nvidia DMS동적최대 8x 절감2~8x미미Post-training

의사결정 매트릭스

짧은 컨텍스트 (< 8K)

  • Flash Attention 2 + 표준 KV Cache로 충분합니다
  • Sparse Attention의 이점이 크지 않습니다

중간 컨텍스트 (8K ~ 32K)

  • Flash Attention 2가 여전히 효과적입니다
  • 배치 처리 시 GQA 모델(Qwen 2.5, Llama 3)을 선택하는 것이 KV Cache 관리에 유리합니다

긴 컨텍스트 (32K+)

  • Sparse Attention이 필수입니다
  • DSA 내장 모델(DeepSeek-V3.2+)을 선택하거나
  • DMS로 기존 모델을 최적화하는 방향을 고려합니다

Needle-in-a-Haystack 작업

  • Sparse Attention의 약점 영역입니다
  • 특정 정보 검색이 중요하면 KV Cache 압축(KVTC 등)이 더 적합합니다

실전: Flash Attention 활성화 확인

현재 사용 중인 모델이 Flash Attention을 제대로 쓰고 있는지 확인하는 방법입니다.

python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.1-8B-Instruct"

# Flash Attention 2 활성화
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    attn_implementation="flash_attention_2",
)

# Attention 구현 확인
attn_layer = model.model.layers[0].self_attn
print(f"Attention class: {attn_layer.__class__.__name__}")
# → LlamaFlashAttention2 (Flash Attention 활성화됨)
# → LlamaSdpaAttention (PyTorch SDPA — Flash Attention 미지원 시 폴백)
# → LlamaAttention (표준 구현)

# 간단한 속도 비교
import time

prompt = "Write a comprehensive analysis of " * 100  # 긴 프롬프트
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096).to(model.device)

# Warmup
with torch.no_grad():
    model(**inputs)

# 측정
torch.cuda.synchronize()
start = time.perf_counter()

with torch.no_grad():
    for _ in range(10):
        model(**inputs)

torch.cuda.synchronize()
elapsed = (time.perf_counter() - start) / 10
print(f"Average forward pass: {elapsed*1000:.1f} ms")
Premium Series4 parts

LLM 추론 최적화 시리즈 — Attention부터 프로덕션 서빙까지

Sparse Attention, KV Cache 압축, PagedAttention 등을 코드로 직접 구현하고 벤치마크합니다. vLLM/TGI 프로덕션 배포까지.

정리

기술핵심 원리추천 상황
Flash AttentionHW 최적화 (Tiling)모든 상황의 기본값
Sliding Window고정 윈도우짧은 컨텍스트, 스트리밍
DSA2단계 선별 Attention긴 컨텍스트 추론
DMS학습 기반 토큰 제거기존 모델 최적화

Flash Attention은 기본값으로 항상 켜두되, 32K+ 컨텍스트를 다룬다면 Sparse Attention 지원 모델을 선택하는 것이 현재 가장 실용적인 가이드라인입니다.

🔒

이어서 읽으려면 로그인이 필요합니다

무료 회원가입으로 전체 콘텐츠를 확인하세요.

관련 포스트