TransformerLens 실전: Activation Patching으로 모델 회로를 읽다

TransformerLens 실전: Activation Patching으로 모델 회로를 읽다
지난 글에서 Lens는 모델의 중간 사고를 읽는 창이라고 했습니다.
하지만 "읽기"만으로는 핵심 질문에 답할 수 없습니다:
모델이 그 정보를 정말로 '사용'하고 있는가?
어떤 layer의 hidden state에 "Paris"라는 정보가 있다고 해서, 그 layer가 최종 답을 만드는 데 실제로 기여하는지는 알 수 없습니다. 정보가 있되 사용되지 않을 수도 있기 때문입니다.
이를 확인하려면 visualization이 아니라 causal intervention이 필요합니다. 모델의 내부를 직접 조작하고, 출력이 어떻게 변하는지 관찰하는 것입니다.
1. TransformerLens: Interpretability의 수술 도구
TransformerLens는 Neel Nanda가 만든 mechanistic interpretability 라이브러리입니다. 핵심 기능은 Transformer의 모든 내부 activation에 hook을 걸어 읽고, 수정하고, 교체할 수 있다는 것입니다.
pip install transformer_lensHookedTransformer: Hook이 달린 모델
TransformerLens의 핵심 클래스는 HookedTransformer입니다. 일반 Transformer와 동일하게 동작하되, 모든 중요한 activation 지점에 HookPoint가 삽입되어 있습니다.
from transformer_lens import HookedTransformer
model = HookedTransformer.from_pretrained("gpt2-small")Hook Point Map: 모든 관측 가능 지점
GPT-2 Small 기준으로, 각 layer마다 다음 hook point들이 있습니다:
Residual Stream:
blocks.{l}.hook_resid_pre— block 입력blocks.{l}.hook_resid_mid— attention 후, MLP 전blocks.{l}.hook_resid_post— block 출력
Attention:
blocks.{l}.attn.hook_q— Query[batch, pos, n_heads, d_head]blocks.{l}.attn.hook_k— Keyblocks.{l}.attn.hook_v— Valueblocks.{l}.attn.hook_pattern— Attention pattern (softmax 후)blocks.{l}.attn.hook_result— Head별 출력
MLP:
blocks.{l}.mlp.hook_pre— activation 전blocks.{l}.mlp.hook_post— activation 후
GPT-2 Small (12 layers, 12 heads)의 경우, 총 100개 이상의 hook point가 있습니다.
ActivationCache: 한 번에 모든 activation 저장
run_with_cache()를 호출하면, 한 번의 forward pass로 모든 hook point의 activation을 저장할 수 있습니다.
prompt = "When John and Mary went to the store, John gave a drink to"
tokens = model.to_tokens(prompt)
logits, cache = model.run_with_cache(tokens)
# cache에서 원하는 activation에 접근
resid = cache["blocks.5.hook_resid_post"] # Layer 5의 residual
attn_pattern = cache["blocks.8.attn.hook_pattern"] # Layer 8의 attention patternActivationCache는 단순한 dictionary가 아니라, 분석에 유용한 메서드를 제공합니다:
cache.decompose_resid(layer)— residual stream을 component별로 분해cache.accumulated_resid(layer)— 누적 residual (Logit Lens용)cache.logit_attrs(direction)— 특정 토큰 방향으로의 기여도 계산cache.stack_head_results(layer)— attention output을 head별로 분리
2. Activation Patching: 인과 추적의 핵심
왜 Patching이 필요한가
Logit Lens는 각 layer에서 "Paris"가 top prediction인지 보여줍니다. 하지만 이것은 상관관계(correlation)일 뿐입니다. 실제로 중요한 것은 인과관계(causation)입니다:
"이 layer의 activation을 바꾸면, 모델의 최종 답이 변하는가?"
Activation Patching의 알고리즘
Activation patching(causal tracing이라고도 불림)은 3단계로 구성됩니다:
Step 1: Clean Run (정상 실행)
모델에 "정답을 맞추는" prompt를 넣고, 모든 activation을 캐싱합니다.
Clean: "When John and Mary went to the store, John gave a drink to" → " Mary" (정답)
Step 2: Corrupted Run (손상된 실행)
약간 변형된 prompt를 넣어 모델이 "오답을 내도록" 만듭니다. 이것이 baseline입니다.
Corrupted: "When John and Mary went to the store, Mary gave a drink to" → " John" (오답)
여기서 corruption은 단순히 이름을 바꾼 것입니다. 이로 인해 모델은 반대 이름을 예측하게 됩니다.
Step 3: Patched Run (패칭 실행)
Corrupted prompt를 다시 넣되, 특정 위치의 activation만 clean run의 것으로 교체합니다. 그리고 출력이 복원되는지 관찰합니다.
Corrupted input → 모델 실행 중 layer 8의 activation을 clean 것으로 교체 → 출력이 "Mary"로 복원되는가?
만약 복원된다면, 그 layer의 activation이 정답을 만드는 데 인과적으로 중요하다는 의미입니다.

패칭 메트릭: Logit Difference
출력의 변화를 측정하기 위해 logit difference를 사용합니다:
def get_logit_diff(logits, correct_token, incorrect_token):
"""정답 토큰과 오답 토큰의 logit 차이"""
return logits[0, -1, correct_token] - logits[0, -1, incorrect_token]이를 정규화하면:
$$\text{normalized\_metric} = \frac{\text{patched\_diff} - \text{corrupted\_diff}}{\text{clean\_diff} - \text{corrupted\_diff}}$$
- 0 = corrupted와 동일 (복원 없음)
- 1 = clean과 동일 (완전 복원)
3. 실전: IOI Task에서의 Activation Patching
IOI (Indirect Object Identification) Task
IOI task는 activation patching의 표준 벤치마크입니다. 다음과 같은 문장에서 간접 목적어를 맞추는 task입니다:
"When John and Mary went to the store, John gave a drink to ___" → " Mary"
이 task가 좋은 이유는 다음과 같습니다:
- 정답이 명확함 (문맥에서 유일하게 결정됨)
- Clean/corrupted 쌍을 만들기 쉬움 (이름 교체)
- GPT-2 Small이 높은 정확도로 맞추므로 회로 분석이 가능
Step-by-Step 코드
import torch
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_act_name
model = HookedTransformer.from_pretrained("gpt2-small")
model.set_use_attn_result(True) # head별 출력 접근 활성화
# 1. Prompt 정의
clean_prompt = "When John and Mary went to the store, John gave a drink to"
corrupted_prompt = "When John and Mary went to the store, Mary gave a drink to"
clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)
# 정답/오답 토큰
mary_token = model.to_single_token(" Mary")
john_token = model.to_single_token(" John")
# 2. Clean & Corrupted run
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)
clean_diff = (clean_logits[0, -1, mary_token] - clean_logits[0, -1, john_token]).item()
corrupted_diff = (corrupted_logits[0, -1, mary_token] - corrupted_logits[0, -1, john_token]).item()
print(f"Clean logit diff: {clean_diff:.2f}") # 양수: Mary를 더 선호
print(f"Corrupted logit diff: {corrupted_diff:.2f}") # 음수: John을 더 선호
# 3. Layer별 Residual Stream Patching
results = torch.zeros(model.cfg.n_layers, clean_tokens.shape[1])
for layer in range(model.cfg.n_layers):
for pos in range(clean_tokens.shape[1]):
hook_name = get_act_name("resid_pre", layer)
def patch_hook(activation, hook, pos=pos):
activation[:, pos, :] = clean_cache[hook.name][:, pos, :]
return activation
patched_logits = model.run_with_hooks(
corrupted_tokens,
fwd_hooks=[(hook_name, patch_hook)]
)
patched_diff = (patched_logits[0, -1, mary_token] - patched_logits[0, -1, john_token]).item()
results[layer, pos] = (patched_diff - corrupted_diff) / (clean_diff - corrupted_diff)결과 해석
Residual stream patching 결과를 heatmap으로 시각화하면 (y축: layer, x축: token position, 색상: 복원 정도) 다음과 같은 패턴이 나타납니다:
- 초기 layer (L0-L4) + 두 번째 "John" (subject) 위치: 가장 높은 복원 효과. 이 단계에서 모델은 "방금 나온 이름이 John"이라는 핵심 정보를 residual stream에 저장합니다. 이 정보를 복원하면 정답이 돌아옵니다.
- 후기 layer (L9-L11) + 마지막 토큰 "to" 위치: 또 다른 핵심 영역. 추출된 정보를 바탕으로 최종 답 "Mary"를 결정하여 출력하는 구간으로, Name Mover Head들이 작동하는 지점입니다.
- 첫 번째 "John"과 "Mary" 위치: 거의 효과 없음. 문장 초반 등장("When John and Mary went to...")은 circuit에 직접 기여하지 않습니다. 효과는 오직 두 번째 "John"(subject)과 마지막 "to"(출력 지점)에 집중됩니다.

이것은 Logit Lens가 제공하는 것과 근본적으로 다른 종류의 증거입니다. Logit Lens는 Layer 8에서 "Paris"가 보인다고 알려주지만, 이는 관찰에 불과합니다. Activation patching은 해당 layer가 *중요하다는 것*을 증명합니다: 이것을 깨뜨리면, 답도 깨집니다.
4. Head별 Patching
Layer에서 개별 Head로
Residual stream patching은 어떤 layer가 중요한지 알려줍니다. 하지만 각 layer 안에는 12개의 attention head와 MLP block이 있습니다. 어떤 head가 실제로 일하고 있는 걸까요? 이를 알아내기 위해 개별 attention head의 출력을 패칭합니다:
# 각 layer의 각 head 출력을 패칭
head_results = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)
for layer in range(model.cfg.n_layers):
for head in range(model.cfg.n_heads):
hook_name = get_act_name("result", layer)
def head_patch_hook(activation, hook, head=head):
activation[:, :, head, :] = clean_cache[hook.name][:, :, head, :]
return activation
patched_logits = model.run_with_hooks(
corrupted_tokens,
fwd_hooks=[(hook_name, head_patch_hook)]
)
patched_diff = (patched_logits[0, -1, mary_token] - patched_logits[0, -1, john_token]).item()
head_results[layer, head] = (patched_diff - corrupted_diff) / (clean_diff - corrupted_diff)get_act_name("result", layer)는 "blocks.{layer}.attn.hook_result"를 반환하며, shape은 [batch, pos, n_heads, d_head]입니다. Hook은 단일 head의 출력만 clean 버전으로 교체합니다.
Heatmap 해석
결과를 heatmap으로 시각화하면 (y축: layer, x축: head, 색상: 복원 정도) 다음 패턴이 보입니다:

- 강한 양수 (파란색): L8.H6, L9.H9, L8.H10 등이 가장 강한 양수를 보입니다. 이 중 L9.H9는 Name Mover Head로, 정답 이름을 출력 position에 직접 쓰는 역할이고, L8.H6과 L8.H10은 S-Inhibition Head로, subject 이름(오답)을 억제하여 정답이 선택되도록 돕습니다.
- 강한 음수 (빨간색): L10.H7, L11.H10 등. 이 head들을 복원하면 오히려 성능이 나빠집니다. Negative Name Mover Head로, 오답 이름(subject)을 출력에 쓰려는 head입니다. Clean run에서 이 head들이 "John"(오답)을 쓰고 있었기 때문에, 이를 복원하면 corrupted context에서 정답 복원이 방해됩니다.
- L0 행이 거의 비어 있음: Duplicate Token Head(L0.H1, L0.H10)는 초기 layer에서 "이 이름이 이미 나왔다"는 정보를 감지하지만, 최종 출력에 직접 영향을 주지는 않습니다. 후기 head들에게 정보를 전달하는 간접적 기여이므로 직접 패칭 heatmap에는 나타나지 않습니다. 이런 간접 경로는 Section 6의 Path Patching으로 추적합니다.
핵심 발견: IOI Circuit
Head별 patching과 추가 분석(path patching 등)을 종합하면, Wang et al. (2022)가 밝힌 IOI circuit의 전체 구조가 드러납니다:
이 head들은 명확한 알고리즘 해석이 가능한 연결된 circuit을 형성합니다:
$$\text{Duplicate Token Heads (L0)} \rightarrow \text{S-Inhibition Heads (L7-8)} \rightarrow \text{Name Mover Heads (L9-10)} \rightarrow \text{Output}$$
쉽게 말하면, 이 circuit은 다음과 같은 알고리즘을 구현하고 있습니다:
- 중복 감지: 초기 head들이 한 이름이 두 번 나타난 것을 감지합니다 (setup절에 한 번, 행동의 주어로 한 번).
- Subject 억제: 중간 layer head들이 중복된 이름(subject)을 억제하여 예측 확률을 낮춥니다.
- 다른 이름 이동: 후기 layer head들이 *나머지* 이름(간접 목적어)을 출력 위치에 복사합니다.
이것은 모델 내부에 대한 인과적 개입만으로 발견된 진정한 알고리즘입니다. Wang et al. (2022)의 IOI 논문에서 처음 완전하게 기술되었습니다.
아래는 Name Mover Head인 L9.H9의 attention pattern입니다. 마지막 토큰("to")이 간접 목적어("Mary")에 강하게 attend하여, 그 이름을 출력에 직접 복사하는 것을 확인할 수 있습니다.

5. TransformerLens의 내장 Patching 함수
위에서는 직접 loop를 돌며 patching했지만, TransformerLens는 내장 patching 함수를 제공합니다:
from transformer_lens.patching import generic_activation_patch
def metric_fn(logits):
diff = logits[0, -1, mary_token] - logits[0, -1, john_token]
return (diff - corrupted_diff) / (clean_diff - corrupted_diff)
# Residual stream patching (layer x position)
result = generic_activation_patch(
model=model,
corrupted_tokens=corrupted_tokens,
clean_cache=clean_cache,
patching_metric=metric_fn,
patch_setter=layer_pos_patch_setter,
activation_name="resid_pre",
index_axis_names=["layer", "pos"],
)
# result.shape: [n_layers, seq_len]내장 patch setter들은 다음과 같습니다:
6. Patching을 넘어서: Path Patching
일반 activation patching은 "이 layer의 이 component가 중요한가?"를 알려줍니다. 하지만 더 세밀한 질문도 할 수 있습니다:
"Head A의 출력이 Head B의 Query로 들어갈 때, 이 특정 경로(path)가 중요한가?"
이것이 path patching입니다. Q, K, V 입력을 분리하여 패칭하면, head 간의 연결 관계를 추적할 수 있습니다.
model.set_use_split_qkv_input(True) # Q, K, V 입력 분리 활성화
# Layer 9 Head 6의 Query 입력만 패칭
hook_name = "blocks.9.attn.hook_q_input"
def q_patch_hook(activation, hook):
activation[:, :, 6, :] = clean_cache[hook.name][:, :, 6, :]
return activationPath patching을 통해 Section 4에서 발견한 IOI circuit의 head 간 연결 구조를 구체적으로 검증할 수 있습니다. 예를 들어, S-Inhibition Head의 Query 입력을 패칭하면 Duplicate Token Head → S-Inhibition Head로의 정보 흐름을 확인할 수 있습니다.
7. 실전 팁
GPU 필수: 패칭은 component마다 forward pass를 한 번씩 실행합니다. Layer x Position 패칭은 수천 번의 forward pass가 필요하므로 GPU 없이는 비현실적입니다.
점진적 접근: resid_pre 패칭으로 중요 layer를 찾고 → head 패칭으로 중요 head를 찾고 → path 패칭으로 head 간 연결을 추적하는 순서를 추천합니다.
배치 처리: 하나의 prompt에서의 결과는 noisy할 수 있습니다. 8-16개의 유사한 prompt를 만들어 평균을 내면 안정적인 결과를 얻습니다.
`model.set_use_attn_result(True)`: head별 출력 패칭을 하려면 반드시 이 설정을 활성화해야 합니다.
Wrap-up
TransformerLens와 activation patching은 interpretability를 상관 분석에서 인과 분석으로 끌어올렸습니다.
하지만 여전히 한 가지 근본적인 한계가 있습니다: activation은 dense하고 polysemantic합니다. 하나의 뉴런이 여러 가지 개념에 동시에 반응하기 때문에, 뉴런 단위로는 깨끗한 해석이 어렵습니다.
다음 글에서는 이 문제를 해결하는 최신 접근법, Sparse Autoencoder(SAE)와 TensorLens를 다룹니다.
References
- TransformerLens GitHub
https://github.com/TransformerLensOrg/TransformerLens
- Neel Nanda. *Activation Patching in TransformerLens*
https://github.com/TransformerLensOrg/TransformerLens/blob/main/demos/Activation_Patching_in_TL_Demo.ipynb
- Meng et al. *Locating and Editing Factual Associations in GPT (ROME)* (2022)
https://arxiv.org/abs/2202.05262
- Wang et al. *Interpretability in the Wild: A Circuit for Indirect Object Identification* (2022)
https://arxiv.org/abs/2211.00593
- Neel Nanda. *Mechanistic Interpretability Intro*
https://www.neelnanda.io/mechanistic-interpretability