CFG-free Distillation: Guidance 없이 빠른 생성

CFG-free Distillation: Guidance 없이 빠른 생성
Classifier-Free Guidance의 2배 연산 비용을 제거. 단일 forward pass로 CFG 품질 달성.
TL;DR
- 문제: CFG는 조건부/무조건부 두 번의 forward pass 필요 (2배 비용)
- 해결: Distillation으로 CFG 효과를 단일 모델에 학습
- 방법: Teacher의 CFG 출력을 Student가 모방
- 결과: 동일 품질, 절반의 연산량, 더 빠른 추론
1. Classifier-Free Guidance 복습
CFG란?
조건부 생성의 품질을 높이는 핵심 기법:
$$\tilde{\epsilon}(x_t, c) = \epsilon(x_t, \varnothing) + w \cdot (\epsilon(x_t, c) - \epsilon(x_t, \varnothing))$$
- $\epsilon(x_t, c)$: 조건부 예측
- $\epsilon(x_t, \varnothing)$: 무조건부 예측
- $w$: guidance scale (보통 7.5)
CFG의 문제점
실시간 애플리케이션에서 치명적!
2. CFG Distillation 아이디어
핵심 통찰
CFG의 출력을 **직접 예측**하도록 학습하면?
Teacher (CFG 사용):
2 forward passes → CFG 결합 → 출력Student (CFG-free):
1 forward pass → 동일한 출력Distillation 목표
$$\mathcal{L} = \mathbb{E}\left[\|\epsilon_\text{student}(x_t, c) - \tilde{\epsilon}_\text{teacher}(x_t, c)\|^2\right]$$
Student가 Teacher의 CFG 결과를 모방.
3. 학습 방법
기본 알고리즘
def cfg_distillation_loss(student, teacher, x0, c, w=7.5):
# 노이즈 추가
t = torch.rand(x0.shape[0], device=x0.device)
noise = torch.randn_like(x0)
x_t = add_noise(x0, t, noise)
# Teacher: CFG 적용
with torch.no_grad():
eps_cond = teacher(x_t, t, c)
eps_uncond = teacher(x_t, t, null_cond)
eps_cfg = eps_uncond + w * (eps_cond - eps_uncond)
# Student: 단일 예측
eps_student = student(x_t, t, c)
return F.mse_loss(eps_student, eps_cfg)Guidance Scale 조건화
다양한 guidance scale을 지원하려면:
def cfg_distillation_with_scale(student, teacher, x0, c):
# 랜덤 guidance scale 샘플링
w = torch.rand(x0.shape[0], device=x0.device) * 10 + 1 # [1, 11]
# Teacher CFG
with torch.no_grad():
eps_cfg = compute_cfg(teacher, x_t, t, c, w)
# Student: scale도 입력으로
eps_student = student(x_t, t, c, w)
return F.mse_loss(eps_student, eps_cfg)이렇게 하면 추론 시 guidance scale 조절 가능!
4. 아키텍처 수정
Guidance Scale 임베딩
class CFGFreeUNet(nn.Module):
def __init__(self, base_unet):
super().__init__()
self.unet = base_unet
# Guidance scale 임베딩
self.w_embed = nn.Sequential(
nn.Linear(1, 256),
nn.SiLU(),
nn.Linear(256, 256)
)
def forward(self, x, t, c, w):
# w를 시간 임베딩에 추가
t_emb = self.time_embed(t)
w_emb = self.w_embed(w.unsqueeze(-1))
combined_emb = t_emb + w_emb
return self.unet(x, combined_emb, c)또는 간단한 방식
고정 guidance scale만 사용한다면:
- 아키텍처 수정 불필요
- 특정 w 값으로만 distillation
5. Progressive Distillation과 결합
CFG + Step Distillation
SDXL-Turbo, SDXL-Lightning 등의 접근:
def combined_distillation(student, teacher, x0, c):
# 1. Step distillation (2 steps → 1 step)
t = sample_timestep()
t_mid = t / 2
x_t = add_noise(x0, t)
# Teacher: 2 steps with CFG
with torch.no_grad():
# Step 1
eps1 = compute_cfg(teacher, x_t, t, c, w=7.5)
x_mid = denoise_step(x_t, eps1, t, t_mid)
# Step 2
eps2 = compute_cfg(teacher, x_mid, t_mid, c, w=7.5)
x_target = denoise_step(x_mid, eps2, t_mid, 0)
# Student: 1 step, no CFG
x_pred = student.denoise(x_t, t, c)
return F.mse_loss(x_pred, x_target)최종 목표
100배 속도 향상!
6. 구현 예제
간단한 CFG-free Distillation
class CFGFreeDistillation:
def __init__(self, student, teacher, guidance_scale=7.5):
self.student = student
self.teacher = teacher
self.w = guidance_scale
# Teacher는 학습하지 않음
for p in teacher.parameters():
p.requires_grad = False
def compute_teacher_cfg(self, x_t, t, c):
"""Teacher의 CFG 출력 계산"""
# 조건부 예측
eps_cond = self.teacher(x_t, t, c)
# 무조건부 예측 (null condition)
null_c = torch.zeros_like(c)
eps_uncond = self.teacher(x_t, t, null_c)
# CFG 결합
return eps_uncond + self.w * (eps_cond - eps_uncond)
def loss(self, x0, c):
B = x0.shape[0]
# 노이즈 샘플링
t = torch.rand(B, device=x0.device)
noise = torch.randn_like(x0)
# Forward diffusion
sigma = t.view(B, 1, 1, 1)
x_t = x0 + sigma * noise
# Teacher CFG target
with torch.no_grad():
target = self.compute_teacher_cfg(x_t, t, c)
# Student prediction
pred = self.student(x_t, t, c)
return F.mse_loss(pred, target)학습 루프
def train_cfg_free(student, teacher, dataloader, epochs=100):
distiller = CFGFreeDistillation(student, teacher)
optimizer = torch.optim.AdamW(student.parameters(), lr=1e-4)
for epoch in range(epochs):
for images, captions in dataloader:
# 텍스트 인코딩
c = text_encoder(captions)
loss = distiller.loss(images, c)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Epoch {epoch}: loss = {loss.item():.4f}")7. Variational Score Distillation
문제점
단순 distillation의 한계:
- Mode collapse 가능성
- Teacher의 한계를 그대로 상속
- 다양성 감소
VSD 접근
Score Distillation Sampling을 활용:
$$\nabla_\theta \mathcal{L}_\text{VSD} = \mathbb{E}\left[w(t)(\epsilon_\text{teacher} - \epsilon_\text{student}) \frac{\partial \epsilon_\text{student}}{\partial \theta}\right]$$
Adversarial Distillation
GAN loss 추가:
def adversarial_distillation_loss(student, teacher, discriminator, x0, c):
# Distillation loss
dist_loss = cfg_distillation_loss(student, teacher, x0, c)
# Generate sample
z = torch.randn_like(x0)
x_gen = student.sample(z, c)
# Adversarial loss
adv_loss = -discriminator(x_gen, c).mean()
return dist_loss + 0.1 * adv_loss8. 실제 모델들
SDXL-Turbo
Stability AI의 접근:
- Adversarial Diffusion Distillation (ADD)
- CFG-free + 1-4 step generation
- GAN discriminator 사용
SDXL-Lightning
ByteDance의 접근:
- Progressive distillation
- CFG distillation
- LoRA 기반 효율적 학습
LCM (Latent Consistency Models)
Consistency distillation + CFG:
- Consistency loss로 step 감소
- CFG 효과 내재화
9. 품질 비교
정량적 결과
거의 동등한 품질, 압도적 속도 향상!
속도 비교 (A100 기준)
40배 빠름!
10. 한계와 미래
현재 한계
미래 방향
- Self-Distillation: Teacher 없이 자체 개선
- Continuous Guidance: 임의의 w 값 지원
- Multi-Modal Guidance: 여러 조건 동시 처리
결론
CFG-free Distillation의 핵심:
- Teacher의 CFG 효과를 Student에 증류
- 2배 연산량 제거
- Step distillation과 결합하면 수십 배 속도 향상
- 실시간 이미지 생성 가능
References
- Sauer, A., et al. "Adversarial Diffusion Distillation" (2023)
- Lin, S., et al. "SDXL-Lightning" (2024)
- Luo, S., et al. "Latent Consistency Models" (2023)
- Ho, J., Salimans, T. "Classifier-Free Diffusion Guidance" (2022)
- Meng, C., et al. "On Distillation of Guided Diffusion Models" (CVPR 2023)