Consistency Models: A New Paradigm for 1-Step Generation
Single-step generation without iterative sampling. OpenAI's innovative approach.
TL;DR
- Consistency Models: Map all points on the same trajectory to the same output
- Self-Consistency: $f(x_t, t) = f(x_{t'}, t')$ for all $t, t'$ on same trajectory
- Two Training Methods: Consistency Distillation (requires teacher) vs Consistency Training (no teacher)
- Result: High-quality 1-step generation, with optional multi-step for better quality
1. Why Consistency Models?
The Fundamental Limitation of Diffusion
Diffusion models require iterative sampling:
```
z ~ N(0, I) → x_T → x_{T-1} → ... → x_1 → x_0
```
No matter how optimized the sampler is:
- DDPM: 1000 steps
- DDIM: 50-100 steps
- DPM-Solver: 10-20 steps
Is 1-step impossible?
Problems with Existing Approaches
The Consistency Models Idea
Key observation:
All points on an ODE trajectory converge to the **same data point**
Therefore:
Learn a function that outputs the **same result** regardless of starting point on trajectory!
2. Self-Consistency Property
Definition
A consistency function $f: (x_t, t) \to x_0$ satisfies:
$$f(x_t, t) = f(x_{t'}, t') \quad \forall t, t' \in [0, T]$$
when $x_t$ and $x_{t'}$ are on the same ODE trajectory.
Intuitive Understanding
```
Noise                              Data
  z ─────●─────●─────●─────●─────> x_0
         ↓     ↓     ↓     ↓
        f()   f()   f()   f()
         ↓     ↓     ↓     ↓
         └─────┴─────┴─────┘
            all the same x_0
```
Following the ODE leads to the same $x_0$, so predicting $x_0$ directly from any intermediate point should be possible.
Boundary Condition
At $t = 0$, the consistency function should be the identity:
$$f(x_0, 0) = x_0$$
If the input is already clean data, it is returned as-is.
3. Consistency Model Architecture
Basic Structure
Design to satisfy boundary condition:
$$f_\theta(x, t) = c_{\text{skip}}(t) \cdot x + c_{\text{out}}(t) \cdot F_\theta(x, t)$$
Where:
- $F_\theta$: Neural network (U-Net, DiT, etc.)
- $c_{\text{skip}}(t)$, $c_{\text{out}}(t)$: Time-dependent weights
Skip Connection Design
To satisfy $f(x, 0) = x$:
$$c_{\text{skip}}(0) = 1, \quad c_{\text{out}}(0) = 0$$
Common choice:
$$c_{\text{skip}}(t) = \frac{\sigma_{\text{data}}^2}{\sigma_{\text{data}}^2 + t^2}$$
$$c_{\text{out}}(t) = \frac{t \cdot \sigma_{\text{data}}}{\sqrt{\sigma_{\text{data}}^2 + t^2}}$$
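Plugging $t = 0$ into these coefficients confirms the boundary condition directly:
$$c_{\text{skip}}(0) = \frac{\sigma_{\text{data}}^2}{\sigma_{\text{data}}^2 + 0} = 1, \qquad c_{\text{out}}(0) = \frac{0 \cdot \sigma_{\text{data}}}{\sqrt{\sigma_{\text{data}}^2 + 0}} = 0$$
so $f_\theta(x, 0) = 1 \cdot x + 0 \cdot F_\theta(x, 0) = x$.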
Time Embedding
For numerical stability as $t \to 0$, the time input is transformed before being fed to the network:
$$t' = \frac{1}{4} \log(t + 1)$$
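As a quick numeric sanity check, here is a minimal sketch (assuming PyTorch; `sigma_data = 0.5` is a commonly used default, not something fixed by the method) that evaluates the coefficients and the rescaled time at a few noise levels:

```python
import torch

sigma_data = 0.5  # assumed default for the EDM-style parameterization

def c_skip(t):
    return sigma_data**2 / (t**2 + sigma_data**2)

def c_out(t):
    return t * sigma_data / torch.sqrt(t**2 + sigma_data**2)

def c_noise(t):
    # Rescaled time fed to the network, per the transform above.
    return 0.25 * torch.log(t + 1.0)

t = torch.tensor([0.0, 0.5, 2.0, 80.0])
print(c_skip(t))   # starts at 1.0 and decays toward 0 as t grows
print(c_out(t))    # starts at 0.0 and approaches sigma_data
print(c_noise(t))  # smooth and bounded near t = 0
```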
4. Consistency Distillation (CD)
Concept
Use a pre-trained diffusion model as teacher:
- Generate ODE trajectory with teacher
- Train consistency model to map different points on trajectory to same output
Algorithm
```python
import torch
import torch.nn.functional as F

# Assumed helpers: sample_timestep, add_noise, teacher_ode_step, and a step
# size delta_t (concrete versions appear in Section 7).
def consistency_distillation_loss(model, teacher, x0):
    # Sample a timestep and the adjacent, slightly less noisy one
    t = sample_timestep()
    t_next = t - delta_t  # one step earlier

    # Add noise to get x_t
    noise = torch.randn_like(x0)
    x_t = add_noise(x0, t, noise)

    # Teacher takes one ODE step: x_t -> x_{t_next}
    with torch.no_grad():
        x_t_next = teacher_ode_step(teacher, x_t, t, t_next)

    # Consistency loss: f(x_t, t) should equal f(x_{t_next}, t_next)
    pred_t = model(x_t, t)
    pred_t_next = model(x_t_next, t_next).detach()  # stop gradient
    return F.mse_loss(pred_t, pred_t_next)
```

Target Network (EMA)
For stable training, use target network:
$$\theta^- \leftarrow \mu \theta^- + (1-\mu) \theta$$
- $\theta$: Training model
- $\theta^-$: EMA target (stop gradient)
- $\mu$: Decay rate (e.g., 0.999)
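A minimal sketch of how the EMA update is interleaved with optimizer steps; here `cd_loss` stands for a variant of the loss above whose detached prediction comes from the EMA copy, and `make_model`, `teacher`, and `dataloader` are placeholders:

```python
import copy
import torch

model = make_model()                 # placeholder constructor
target_model = copy.deepcopy(model)  # theta^-: EMA copy of theta
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
mu = 0.999                           # EMA decay rate

for x0 in dataloader:
    # Consistency loss with the detached branch evaluated by target_model
    loss = cd_loss(model, target_model, teacher, x0)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # theta^- <- mu * theta^- + (1 - mu) * theta
    with torch.no_grad():
        for p, p_ema in zip(model.parameters(), target_model.parameters()):
            p_ema.mul_(mu).add_(p, alpha=1 - mu)
```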
Loss Function
$$\mathcal{L}_{\text{CD}} = \mathbb{E}\left[d(f_\theta(x_{t_n}, t_n), f_{\theta^-}(x_{t_{n-1}}, t_{n-1}))\right]$$
Where $d$ is a distance metric (L2, LPIPS, etc.).
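A minimal sketch of how the distance could be made configurable, assuming PyTorch and the `lpips` package for the perceptual variant (the helper name `distance` is ours, not from the paper):

```python
import torch.nn.functional as F
import lpips  # pip install lpips; only needed for the perceptual option

lpips_fn = lpips.LPIPS(net="vgg")  # perceptual distance, expects images in [-1, 1]

def distance(pred, target, metric="l2"):
    # d(., .) from the loss above: squared L2 or LPIPS
    if metric == "l2":
        return F.mse_loss(pred, target)
    if metric == "lpips":
        return lpips_fn(pred, target).mean()
    raise ValueError(f"unknown metric: {metric}")
```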
5. Consistency Training (CT)
Learning Without a Teacher
Consistency Distillation requires a teacher model. But we can also learn without a teacher!
Key Idea
Instead of solving the ODE exactly, enforce consistency across infinitesimally small steps:
$$\lim_{\Delta t \to 0} f(x_{t+\Delta t}, t+\Delta t) = f(x_t, t)$$
Algorithm
```python
def consistency_training_loss(model, x0):
    # Sample a timestep and the adjacent one
    t = sample_timestep()
    t_next = t - delta_t

    # Add noise: the SAME noise draw is reused at both timesteps
    noise = torch.randn_like(x0)
    x_t = add_noise(x0, t, noise)
    x_t_next = add_noise(x0, t_next, noise)  # same noise!

    # Consistency loss
    pred_t = model(x_t, t)
    pred_t_next = model(x_t_next, t_next).detach()
    return F.mse_loss(pred_t, pred_t_next)
```

Key difference: instead of a teacher ODE step, the same clean sample is noised at two adjacent times with the same noise draw.
Why Does It Work?
When $\Delta t \to 0$:
$$x_{t+\Delta t} \approx x_t + \text{small perturbation}$$
This perturbation aligns with the ODE direction, so enforcing consistency over infinitesimal steps implies consistency along the entire trajectory.
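Concretely, with the noise parameterization $x_t = x_0 + t\,\epsilon$ (the same $\epsilon$ at both times, as in the code above):
$$x_{t+\Delta t} - x_t = \big(x_0 + (t + \Delta t)\,\epsilon\big) - \big(x_0 + t\,\epsilon\big) = \Delta t \, \epsilon$$
which is exactly the Euler step of the probability-flow ODE when the score at $x_t$ is estimated from this single $(x_0, \epsilon)$ pair, so no teacher is needed to construct the training pair.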
CD vs CT Comparison

| | Consistency Distillation (CD) | Consistency Training (CT) |
|---|---|---|
| Teacher model | Required (pre-trained diffusion model) | Not required |
| Pair construction | One teacher ODE step $x_t \to x_{t_{\text{next}}}$ | Same $x_0$ and noise at two adjacent times |
| Sample quality | Slightly higher | Slightly lower |
6. Sampling
1-Step Sampling
The simplest method:
```python
def sample_one_step(model, z):
    # z ~ N(0, I)
    # Directly predict x_0
    return model(z, T)
```

Done! Generation without iteration.
Multi-Step Sampling (Quality Improvement)
For higher quality:
```python
def sample_multi_step(model, z, timesteps):
    """
    timesteps: [T, t_1, t_2, ..., 0] (decreasing)
    """
    x = z
    for i in range(len(timesteps) - 1):
        t = timesteps[i]
        t_next = timesteps[i + 1]

        # Denoise to x_0
        x_0 = model(x, t)

        # Add noise back to t_next (if not the last step)
        if t_next > 0:
            noise = torch.randn_like(x)
            x = add_noise(x_0, t_next, noise)

    return x_0
```

Principle:
- Predict $x_0$ from current $x_t$
- Add noise back to get $x_{t'}$
- Repeat
This alternates denoising and noise injection for quality improvement.
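For example, a 2-step schedule with the sampler above could look like this (the time values are illustrative; `model` and `T` refer to the names used earlier):

```python
import torch

# z ~ N(0, I); with EDM-style schedules this is usually scaled by the
# largest noise level before the first model call.
z = torch.randn(1, 3, 32, 32)

# Full denoise from T, re-noise to T/2, final denoise.
timesteps = [T, T / 2, 0.0]
sample = sample_multi_step(model, z, timesteps)
```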
7. Implementation
Consistency Model Class
```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConsistencyModel(nn.Module):
    def __init__(self, network, sigma_data=0.5):
        super().__init__()
        self.network = network
        self.sigma_data = sigma_data

    def c_skip(self, t):
        return self.sigma_data**2 / (t**2 + self.sigma_data**2)

    def c_out(self, t):
        return t * self.sigma_data / torch.sqrt(t**2 + self.sigma_data**2)

    def forward(self, x, t):
        # Skip connection enforcing the boundary condition f(x, 0) = x
        c_skip = self.c_skip(t)
        c_out = self.c_out(t)
        if c_skip.dim() == 1:
            c_skip = c_skip[:, None, None, None]
            c_out = c_out[:, None, None, None]
        F_x = self.network(x, t)
        return c_skip * x + c_out * F_x
```

Consistency Distillation Training
```python
class ConsistencyDistillation:
    def __init__(self, model, teacher, ema_decay=0.999):
        self.model = model
        self.teacher = teacher
        self.target_model = copy.deepcopy(model)
        self.ema_decay = ema_decay

    def ode_step(self, x, t, t_next):
        """One Euler step of the teacher ODE."""
        with torch.no_grad():
            score = self.teacher(x, t)
            dt = (t_next - t)[:, None, None, None]  # broadcast over image dims
            x_next = x + score * dt
        return x_next

    def loss(self, x0):
        B = x0.shape[0]

        # Sample timesteps in [eps, T]; T, eps, delta_t are assumed globals
        t = torch.rand(B, device=x0.device) * (T - eps) + eps
        t_next = (t - delta_t).clamp(min=eps)

        # Forward diffusion: x_t = x_0 + t * noise
        noise = torch.randn_like(x0)
        x_t = x0 + t[:, None, None, None] * noise

        # Teacher ODE step x_t -> x_{t_next}
        x_t_next = self.ode_step(x_t, t, t_next)

        # Consistency loss against the EMA target network
        pred = self.model(x_t, t)
        with torch.no_grad():
            target = self.target_model(x_t_next, t_next)
        return F.mse_loss(pred, target)

    def update_target(self):
        """EMA update of the target network."""
        with torch.no_grad():
            for p, p_target in zip(self.model.parameters(),
                                   self.target_model.parameters()):
                p_target.data.mul_(self.ema_decay).add_(
                    p.data, alpha=1 - self.ema_decay)
```

Consistency Training
```python
class ConsistencyTraining:
    def __init__(self, model, ema_decay=0.999):
        self.model = model
        self.target_model = copy.deepcopy(model)
        self.ema_decay = ema_decay

    def loss(self, x0):
        B = x0.shape[0]

        # Sample timesteps (same assumed globals T, eps, delta_t as above)
        t = torch.rand(B, device=x0.device) * (T - eps) + eps
        t_next = (t - delta_t).clamp(min=eps)

        # Same noise for both timesteps!
        noise = torch.randn_like(x0)
        x_t = x0 + t[:, None, None, None] * noise
        x_t_next = x0 + t_next[:, None, None, None] * noise

        # Consistency loss against the EMA target network
        pred = self.model(x_t, t)
        with torch.no_grad():
            target = self.target_model(x_t_next, t_next)
        return F.mse_loss(pred, target)

    def update_target(self):
        """EMA update of the target network (as in ConsistencyDistillation)."""
        with torch.no_grad():
            for p, p_target in zip(self.model.parameters(),
                                   self.target_model.parameters()):
                p_target.data.mul_(self.ema_decay).add_(
                    p.data, alpha=1 - self.ema_decay)
```

8. Improved Consistency Training (iCT)
Problems with Original CT
- Unstable early in training
- Error accumulation with large $\Delta t$
- Slow convergence
Improvements
- Adaptive $\Delta t$: Decrease $\Delta t$ during training
- Improved noise schedule: EDM-style noise schedule
- Better loss weighting: Time-dependent weighting
```python
def adaptive_delta_t(step, total_steps):
    """Delta t decreases over the course of training."""
    progress = step / total_steps
    return delta_t_max * (1 - progress) + delta_t_min * progress
```

9. Experimental Results
Results are reported on CIFAR-10 (FID) and ImageNet 64x64.
Key Findings
- 1-step CD outperforms existing distillation methods
- 2-step sampling significantly improves quality over 1-step
- CT is slightly behind CD in quality but requires no teacher
10. Latent Consistency Models (LCM)
Application to Stable Diffusion
Train Consistency Models in latent space:
```python
# Encode to latent
z = vae.encode(image)

# Train consistency model in latent space
z_0_pred = consistency_model(z_t, t)

# Decode for visualization
image_pred = vae.decode(z_0_pred)
```

LCM Achievements
- Matches Stable Diffusion quality in about 4 steps
- 5-10x speedup over the original sampler
- Compatible with classifier-free guidance (CFG)
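A hedged sketch of latent-space sampling, reusing the multi-step sampler from Section 6; `vae`, `consistency_model`, and `T` are the placeholder names from the snippets above, and the latent shape assumes an SD-style VAE:

```python
import torch

# Start from Gaussian noise in latent space (4x64x64 for 512x512 SD-style models)
z_T = torch.randn(1, 4, 64, 64)

# 1-step: predict the clean latent directly
z_0 = consistency_model(z_T, T)

# Or multi-step refinement, as in Section 6
z_0 = sample_multi_step(consistency_model, z_T, timesteps=[T, T / 2, 0.0])

# Decode back to pixel space
image = vae.decode(z_0)
```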
LCM-LoRA
Efficient training with LoRA:
```python
# Base SD model + LCM-LoRA
pipe = StableDiffusionPipeline.from_pretrained("...")
pipe.load_lora_weights("lcm-lora-sdv1-5")

# Fast generation
image = pipe(prompt, num_inference_steps=4).images[0]
```

Conclusion
Key to Consistency Models:
- Leverage self-consistency property
- Predict endpoint instead of learning ODE trajectory directly
- Enable 1-step generation while allowing multi-step for quality improvement
References
- Song, Y., et al. "Consistency Models" (ICML 2023)
- Song, Y., Dhariwal, P. "Improved Techniques for Training Consistency Models" (2023)
- Luo, S., et al. "Latent Consistency Models" (2023)
- Karras, T., et al. "Elucidating the Design Space of Diffusion-Based Generative Models" (NeurIPS 2022)