CFG-free Distillation: Fast Generation Without Guidance

Eliminating the 2x computational cost of Classifier-Free Guidance. Achieving CFG quality with a single forward pass.
TL;DR
- Problem: CFG requires two forward passes (conditional + unconditional) = 2x cost
- Solution: Distill CFG effect into a single model
- Method: Student mimics Teacher's CFG output
- Result: Same quality, half the computation, faster inference
1. Classifier-Free Guidance Review
What is CFG?
The key technique for improving conditional generation quality:
$$\tilde{\epsilon}(x_t, c) = \epsilon(x_t, \varnothing) + w \cdot (\epsilon(x_t, c) - \epsilon(x_t, \varnothing))$$
- $\epsilon(x_t, c)$: Conditional prediction
- $\epsilon(x_t, \varnothing)$: Unconditional prediction
- $w$: Guidance scale (typically 7.5)
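As a concrete illustration of the cost, here is a minimal sketch of one CFG denoising step (model, x_t, t, cond, and null_cond are placeholders for a noise-prediction network and its inputs):
import torch

def cfg_denoise_step(model, x_t, t, cond, null_cond, w=7.5):
    # Two forward passes through the same network
    eps_cond = model(x_t, t, cond)         # conditional prediction
    eps_uncond = model(x_t, t, null_cond)  # unconditional prediction
    # CFG combination: push the prediction away from the unconditional one
    return eps_uncond + w * (eps_cond - eps_uncond)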
Problems with CFG
Every denoising step requires two forward passes (conditional and unconditional), doubling compute, latency, and memory traffic per step. Critical for real-time applications!
2. CFG Distillation Idea
Key Insight
What if we train a model to **directly predict** the CFG output?
Teacher (with CFG): 2 forward passes → CFG combination → output
Student (CFG-free): 1 forward pass → same output
Distillation Objective
$$\mathcal{L} = \mathbb{E}\left[\|\epsilon_\text{student}(x_t, c) - \tilde{\epsilon}_\text{teacher}(x_t, c)\|^2\right]$$
Student mimics Teacher's CFG result.
3. Training Method
Basic Algorithm
import torch
import torch.nn.functional as F

def cfg_distillation_loss(student, teacher, x0, c, w=7.5):
    # Add noise (add_noise and null_cond are assumed helpers; see the sketches at the end of this section)
    t = torch.rand(x0.shape[0], device=x0.device)
    noise = torch.randn_like(x0)
    x_t = add_noise(x0, t, noise)

    # Teacher: apply CFG (two forward passes, no gradients)
    with torch.no_grad():
        eps_cond = teacher(x_t, t, c)
        eps_uncond = teacher(x_t, t, null_cond)
        eps_cfg = eps_uncond + w * (eps_cond - eps_uncond)

    # Student: single prediction should match the CFG output
    eps_student = student(x_t, t, c)
    return F.mse_loss(eps_student, eps_cfg)

Guidance Scale Conditioning
To support various guidance scales:
def cfg_distillation_with_scale(student, teacher, x0, c):
    # Sample a random guidance scale per example
    w = torch.rand(x0.shape[0], device=x0.device) * 10 + 1  # [1, 11]

    # Add noise (same as above)
    t = torch.rand(x0.shape[0], device=x0.device)
    noise = torch.randn_like(x0)
    x_t = add_noise(x0, t, noise)

    # Teacher CFG at the sampled scale
    with torch.no_grad():
        eps_cfg = compute_cfg(teacher, x_t, t, c, w)

    # Student takes the scale as an additional input
    eps_student = student(x_t, t, c, w)
    return F.mse_loss(eps_student, eps_cfg)

This allows the guidance scale to be adjusted at inference time!
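For reference, here are minimal sketches of the compute_cfg and add_noise helpers used above. They assume the simple sigma(t) = t noising used in the implementation example of Section 6 and a zero embedding as the null condition; both choices are assumptions, not fixed by the method.
def compute_cfg(teacher, x_t, t, c, w):
    # Classifier-free guidance: combine conditional and unconditional predictions
    null_c = torch.zeros_like(c)  # assumed null condition
    eps_cond = teacher(x_t, t, c)
    eps_uncond = teacher(x_t, t, null_c)
    if torch.is_tensor(w):
        # Broadcast per-sample scales over spatial dimensions
        w = w.view(-1, *([1] * (x_t.dim() - 1)))
    return eps_uncond + w * (eps_cond - eps_uncond)

def add_noise(x0, t, noise):
    # Simple forward diffusion with sigma(t) = t (matches Section 6's example)
    sigma = t.view(-1, *([1] * (x0.dim() - 1)))
    return x0 + sigma * noise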
4. Architecture Modifications
Guidance Scale Embedding
import torch.nn as nn

class CFGFreeUNet(nn.Module):
    def __init__(self, base_unet):
        super().__init__()
        self.unet = base_unet
        # Guidance scale embedding (scalar w → time-embedding dimension)
        self.w_embed = nn.Sequential(
            nn.Linear(1, 256),
            nn.SiLU(),
            nn.Linear(256, 256)
        )

    def forward(self, x, t, c, w):
        # Add the w embedding to the time embedding
        # (assumes the base UNet exposes its time embedding module)
        t_emb = self.unet.time_embed(t)
        w_emb = self.w_embed(w.unsqueeze(-1))
        combined_emb = t_emb + w_emb
        return self.unet(x, combined_emb, c)
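Inference is then a single call per step with the desired scale (a short sketch; base_unet, x_t, t, and c are placeholders):
model = CFGFreeUNet(base_unet)
w = torch.full((x_t.shape[0],), 7.5, device=x_t.device)  # per-sample guidance scale
eps = model(x_t, t, c, w)  # one forward pass, CFG effect included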
Or: A Simpler Approach
If only a fixed guidance scale is needed:
- No architecture modification needed
- Distill with specific w value
5. Combining with Progressive Distillation
CFG + Step Distillation
Approach used by SDXL-Turbo, SDXL-Lightning:
def combined_distillation(student, teacher, x0, c):
    # 1. Step distillation setup (2 teacher steps → 1 student step)
    t = sample_timestep()
    t_mid = t / 2
    x_t = add_noise(x0, t, torch.randn_like(x0))

    # Teacher: 2 denoising steps, each with CFG
    with torch.no_grad():
        # Step 1: t → t_mid
        eps1 = compute_cfg(teacher, x_t, t, c, w=7.5)
        x_mid = denoise_step(x_t, eps1, t, t_mid)
        # Step 2: t_mid → 0
        eps2 = compute_cfg(teacher, x_mid, t_mid, c, w=7.5)
        x_target = denoise_step(x_mid, eps2, t_mid, 0)

    # Student: 1 step, no CFG
    x_pred = student.denoise(x_t, t, c)
    return F.mse_loss(x_pred, x_target)
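The denoise_step helper above is sampler-dependent. Under the simple sigma(t) = t parameterization used in Section 6, an Euler step could look like the following sketch (one valid choice among several):
def denoise_step(x_t, eps, t, t_next):
    # Euler step under x_t = x0 + t * noise:
    #   x0_hat = x_t - t * eps,  x_{t_next} = x0_hat + t_next * eps
    dt = t_next - t  # negative when denoising
    if torch.is_tensor(dt):
        dt = dt.view(-1, *([1] * (x_t.dim() - 1)))
    return x_t + dt * eps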
Final Goal
100x speedup!
6. Implementation Example
Simple CFG-free Distillation
class CFGFreeDistillation:
    def __init__(self, student, teacher, guidance_scale=7.5):
        self.student = student
        self.teacher = teacher
        self.w = guidance_scale
        # Freeze the teacher
        for p in teacher.parameters():
            p.requires_grad = False

    def compute_teacher_cfg(self, x_t, t, c):
        """Compute the Teacher's CFG output."""
        # Conditional prediction
        eps_cond = self.teacher(x_t, t, c)
        # Unconditional prediction (null condition)
        null_c = torch.zeros_like(c)
        eps_uncond = self.teacher(x_t, t, null_c)
        # CFG combination
        return eps_uncond + self.w * (eps_cond - eps_uncond)

    def loss(self, x0, c):
        B = x0.shape[0]
        # Sample noise level and noise
        t = torch.rand(B, device=x0.device)
        noise = torch.randn_like(x0)
        # Forward diffusion with sigma(t) = t
        sigma = t.view(B, 1, 1, 1)
        x_t = x0 + sigma * noise
        # Teacher CFG target (no gradients)
        with torch.no_grad():
            target = self.compute_teacher_cfg(x_t, t, c)
        # Student prediction
        pred = self.student(x_t, t, c)
        return F.mse_loss(pred, target)

Training Loop
def train_cfg_free(student, teacher, dataloader, epochs=100):
    distiller = CFGFreeDistillation(student, teacher)
    optimizer = torch.optim.AdamW(student.parameters(), lr=1e-4)

    for epoch in range(epochs):
        for images, captions in dataloader:
            # Text encoding (text_encoder is an assumed, pretrained encoder)
            c = text_encoder(captions)
            loss = distiller.loss(images, c)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch}: loss = {loss.item():.4f}")
7. Variational Score Distillation
Limitations
Limitations of simple distillation:
- Potential mode collapse
- Inherits Teacher's limitations
- Reduced diversity
VSD Approach
Building on Score Distillation Sampling (SDS), the gradient is taken through the noised student sample $x_t$ rather than through the prediction itself:
$$\nabla_\theta \mathcal{L}_\text{VSD} = \mathbb{E}_{t,\epsilon}\left[w(t)\left(\epsilon_\text{teacher}(x_t, c) - \epsilon_\text{student}(x_t, c)\right) \frac{\partial x_t}{\partial \theta}\right]$$
Here $x_t$ is a noised version of a sample generated by the student with parameters $\theta$, so the teacher-student score difference directly steers the student's outputs.
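A minimal sketch of how this gradient could be applied with the stop-gradient trick. This is heavily simplified: the student network is reused both to generate the sample (via the same student.sample interface as in the adversarial example below) and to provide its own score, whereas full VSD (e.g., ProlificDreamer) trains a separate score network on the student's outputs; the weighting and names are illustrative.
def vsd_loss(student, teacher, z, c, w=7.5, w_weight=1.0):
    # Student generates a sample (treated as a one-step generator)
    x0 = student.sample(z, c)

    # Noise the generated sample at a random level (sigma(t) = t, as in Section 6)
    t = torch.rand(x0.shape[0], device=x0.device)
    sigma = t.view(-1, 1, 1, 1)
    x_t = x0 + sigma * torch.randn_like(x0)

    # Score difference; teacher uses CFG, no gradient through the scores themselves
    with torch.no_grad():
        grad = w_weight * (compute_cfg(teacher, x_t, t, c, w) - student(x_t, t, c))

    # Surrogate loss whose gradient w.r.t. student parameters matches
    # w(t) * (eps_teacher - eps_student) * d x_t / d theta
    return (grad * x_t).sum()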
Adversarial Distillation
Adding GAN loss:
def adversarial_distillation_loss(student, teacher, discriminator, x0, c):
    # Distillation loss (as defined in Section 3)
    dist_loss = cfg_distillation_loss(student, teacher, x0, c)

    # Generate a sample from noise
    z = torch.randn_like(x0)
    x_gen = student.sample(z, c)

    # Adversarial generator loss
    adv_loss = -discriminator(x_gen, c).mean()

    return dist_loss + 0.1 * adv_loss

8. Real-World Models
SDXL-Turbo
Stability AI's approach:
- Adversarial Diffusion Distillation (ADD)
- CFG-free + 1-4 step generation
- Uses GAN discriminator
SDXL-Lightning
ByteDance's approach:
- Progressive distillation
- CFG distillation
- Efficient training with LoRA
LCM (Latent Consistency Models)
Consistency distillation + CFG:
- Consistency loss for step reduction
- Internalized CFG effect
9. Quality Comparison
Quantitative Results
Nearly identical quality to the CFG teacher, with half the compute per denoising step!
Speed Comparison (A100)
Up to ~40x faster end-to-end when combined with step reduction!
10. Limitations and Future
Current Limitations
- Student quality is capped by the Teacher's CFG output
- Guidance scale is fixed at distillation time unless w-conditioning is used (Section 4)
- Requires an additional distillation training run
- Risk of reduced diversity or mode collapse (Section 7)
Future Directions
- Self-Distillation: Self-improvement without Teacher
- Continuous Guidance: Support arbitrary w values
- Multi-Modal Guidance: Handle multiple conditions simultaneously
Conclusion
Key Points of CFG-free Distillation:
- Distill Teacher's CFG effect into Student
- Eliminate 2x computational overhead
- Combined with step distillation for 10-100x speedup
- Enables real-time image generation
References
- Sauer, A., et al. "Adversarial Diffusion Distillation" (2023)
- Lin, S., et al. "SDXL-Lightning" (2024)
- Luo, S., et al. "Latent Consistency Models" (2023)
- Ho, J., Salimans, T. "Classifier-Free Diffusion Guidance" (2022)
- Meng, C., et al. "On Distillation of Guided Diffusion Models" (CVPR 2023)