Consistency Models: 1-Step 생성을 위한 새로운 패러다임

Consistency Models: 1-Step 생성을 위한 새로운 패러다임
Diffusion의 반복 샘플링 없이 단 한 번에. OpenAI의 혁신적 접근법.
TL;DR
- Consistency Models: 동일 trajectory 위의 모든 점을 같은 출력으로 매핑하는 모델
- Self-Consistency: $f(x_t, t) = f(x_{t'}, t')$ for all $t, t'$ on same trajectory
- 두 가지 학습법: Consistency Distillation (교사 필요) vs Consistency Training (교사 불필요)
- 결과: 1-step으로 고품질 생성, 필요시 multi-step으로 품질 향상 가능
1. 왜 Consistency Models인가?
Diffusion의 근본적 한계
Diffusion 모델은 반복적 샘플링이 필수입니다:
z ~ N(0,I) → x_T → x_{T-1} → ... → x_1 → x_0아무리 최적화해도:
- DDPM: 1000 스텝
- DDIM: 50-100 스텝
- DPM-Solver: 10-20 스텝
1-step은 불가능한가?
기존 접근법들의 문제
Consistency Models의 아이디어
핵심 관찰:
ODE trajectory 위의 모든 점은 **같은 데이터 포인트**로 수렴한다
따라서:
trajectory 위의 어떤 점에서 시작하든, **같은 출력**을 내는 함수를 학습하자!
2. Self-Consistency Property
정의
Consistency function $f: (x_t, t) \to x_0$는 다음을 만족:
$$f(x_t, t) = f(x_{t'}, t') \quad \forall t, t' \in [0, T]$$
단, $x_t$와 $x_{t'}$가 같은 ODE trajectory 위에 있을 때.
직관적 이해
Noise Data
z ─────●─────●─────●─────●─────> x_0
↓ ↓ ↓ ↓
f() f() f() f()
↓ ↓ ↓ ↓
└─────┴─────┴─────┘
모두 같은 x_0ODE를 따라가면 결국 같은 $x_0$에 도달하므로, 중간 어느 점에서든 바로 $x_0$를 예측할 수 있어야 합니다.
Boundary Condition
$t = 0$에서는 identity가 되어야 합니다:
$$f(x_0, 0) = x_0$$
이미 데이터에 있으면, 그대로 반환.
3. Consistency Model 아키텍처
기본 구조
Boundary condition을 만족하기 위한 설계:
$$f_\theta(x, t) = c_{\text{skip}}(t) \cdot x + c_{\text{out}}(t) \cdot F_\theta(x, t)$$
여기서:
- $F_\theta$: 신경망 (U-Net, DiT 등)
- $c_{\text{skip}}(t)$, $c_{\text{out}}(t)$: 시간에 따른 가중치
Skip Connection 설계
Boundary condition $f(x, 0) = x$를 만족하려면:
$$c_{\text{skip}}(0) = 1, \quad c_{\text{out}}(0) = 0$$
일반적인 선택:
$$c_{\text{skip}}(t) = \frac{\sigma_{\text{data}}^2}{\sigma_{\text{data}}^2 + t^2}$$
$$c_{\text{out}}(t) = \frac{t \cdot \sigma_{\text{data}}}{\sqrt{\sigma_{\text{data}}^2 + t^2}}$$
시간 임베딩
$t \to 0$ 근처에서 안정성을 위해 시간을 변환:
$$t' = \frac{1}{4} \log(t + 1)$$
4. Consistency Distillation (CD)
개념
미리 학습된 diffusion 모델을 교사로 사용:
- 교사 모델로 ODE trajectory 생성
- Consistency model이 trajectory의 다른 점들을 같은 출력으로 매핑하도록 학습
알고리즘
def consistency_distillation_loss(model, teacher, x0):
# Sample time
t = sample_timestep()
t_next = t - delta_t # one step earlier
# Add noise to get x_t
noise = torch.randn_like(x0)
x_t = add_noise(x0, t, noise)
# Teacher takes one ODE step: x_t -> x_{t_next}
with torch.no_grad():
x_t_next = teacher_ode_step(teacher, x_t, t, t_next)
# Consistency loss: f(x_t, t) should equal f(x_{t_next}, t_next)
pred_t = model(x_t, t)
pred_t_next = model(x_t_next, t_next).detach() # stop gradient
return F.mse_loss(pred_t, pred_t_next)Target Network (EMA)
안정적 학습을 위해 target network 사용:
$$\theta^- \leftarrow \mu \theta^- + (1-\mu) \theta$$
- $\theta$: 학습되는 모델
- $\theta^-$: EMA target (stop gradient)
- $\mu$: decay rate (0.999 등)
Loss Function
$$\mathcal{L}_{\text{CD}} = \mathbb{E}\left[d(f_\theta(x_{t_n}, t_n), f_{\theta^-}(x_{t_{n-1}}, t_{n-1}))\right]$$
여기서 $d$는 distance metric (L2, LPIPS 등).
5. Consistency Training (CT)
교사 없이 학습하기
Consistency Distillation은 교사 모델이 필요합니다. 하지만 교사 없이 직접 학습할 수도 있습니다!
핵심 아이디어
ODE를 정확히 풀지 않고, 무한소 스텝에서의 consistency를 강제:
$$\lim_{\Delta t \to 0} f(x_{t+\Delta t}, t+\Delta t) = f(x_t, t)$$
알고리즘
def consistency_training_loss(model, x0):
# Sample time
t = sample_timestep()
t_next = t - delta_t
# Add noise
noise = torch.randn_like(x0)
x_t = add_noise(x0, t, noise)
x_t_next = add_noise(x0, t_next, noise) # same noise!
# Consistency loss
pred_t = model(x_t, t)
pred_t_next = model(x_t_next, t_next).detach()
return F.mse_loss(pred_t, pred_t_next)핵심 차이: 교사의 ODE step 대신, 같은 noise로 다른 시간에서 샘플링.
왜 작동하는가?
$\Delta t \to 0$일 때:
$$x_{t+\Delta t} \approx x_t + \text{small perturbation}$$
이 perturbation은 ODE의 방향과 일치합니다. 따라서 무한소 단계에서 consistency를 강제하면, 전체 trajectory에서도 consistency가 성립합니다.
CD vs CT 비교
6. 샘플링
1-Step 샘플링
가장 단순한 방법:
def sample_one_step(model, z):
# z ~ N(0, I)
# 바로 x_0 예측
return model(z, T)끝! 반복 없이 한 번에 생성.
Multi-Step 샘플링 (품질 향상)
더 높은 품질을 원하면:
def sample_multi_step(model, z, timesteps):
"""
timesteps: [T, t_1, t_2, ..., 0] (decreasing)
"""
x = z
for i in range(len(timesteps) - 1):
t = timesteps[i]
t_next = timesteps[i + 1]
# Denoise to x_0
x_0 = model(x, t)
# Add noise back to t_next (if not last step)
if t_next > 0:
noise = torch.randn_like(x)
x = add_noise(x_0, t_next, noise)
return x_0원리:
- 현재 $x_t$에서 $x_0$ 예측
- $x_0$에 다시 노이즈 추가하여 $x_{t'}$ 생성
- 반복
이렇게 하면 denoising과 noise injection을 번갈아 수행하여 품질 향상.
7. 구현
Consistency Model 클래스
class ConsistencyModel(nn.Module):
def __init__(self, network, sigma_data=0.5):
super().__init__()
self.network = network
self.sigma_data = sigma_data
def c_skip(self, t):
return self.sigma_data**2 / (t**2 + self.sigma_data**2)
def c_out(self, t):
return t * self.sigma_data / torch.sqrt(t**2 + self.sigma_data**2)
def forward(self, x, t):
# Skip connection for boundary condition
c_skip = self.c_skip(t)
c_out = self.c_out(t)
if c_skip.dim() == 1:
c_skip = c_skip[:, None, None, None]
c_out = c_out[:, None, None, None]
F_x = self.network(x, t)
return c_skip * x + c_out * F_xConsistency Distillation 학습
class ConsistencyDistillation:
def __init__(self, model, teacher, ema_decay=0.999):
self.model = model
self.teacher = teacher
self.target_model = copy.deepcopy(model)
self.ema_decay = ema_decay
def ode_step(self, x, t, t_next):
"""One step of teacher ODE."""
# Using teacher to estimate velocity/score
with torch.no_grad():
score = self.teacher(x, t)
# Euler step
dt = t_next - t
x_next = x + score * dt
return x_next
def loss(self, x0):
B = x0.shape[0]
# Sample timesteps
t = torch.rand(B, device=x0.device) * (T - eps) + eps
t_next = t - delta_t
t_next = t_next.clamp(min=eps)
# Forward diffusion
noise = torch.randn_like(x0)
x_t = x0 + t[:, None, None, None] * noise
# Teacher ODE step
x_t_next = self.ode_step(x_t, t, t_next)
# Consistency loss
pred = self.model(x_t, t)
target = self.target_model(x_t_next, t_next)
return F.mse_loss(pred, target)
def update_target(self):
"""EMA update of target network."""
with torch.no_grad():
for p, p_target in zip(self.model.parameters(),
self.target_model.parameters()):
p_target.data.mul_(self.ema_decay).add_(
p.data, alpha=1 - self.ema_decay)Consistency Training 학습
class ConsistencyTraining:
def __init__(self, model, ema_decay=0.999):
self.model = model
self.target_model = copy.deepcopy(model)
self.ema_decay = ema_decay
def loss(self, x0):
B = x0.shape[0]
# Sample timesteps
t = torch.rand(B, device=x0.device) * (T - eps) + eps
t_next = t - delta_t
t_next = t_next.clamp(min=eps)
# Same noise for both timesteps!
noise = torch.randn_like(x0)
x_t = x0 + t[:, None, None, None] * noise
x_t_next = x0 + t_next[:, None, None, None] * noise
# Consistency loss
pred = self.model(x_t, t)
target = self.target_model(x_t_next, t_next)
return F.mse_loss(pred, target)8. Improved Consistency Training (iCT)
원본 CT의 문제점
- 학습 초기에 불안정
- 큰 $\Delta t$에서 오류 누적
- 수렴이 느림
개선 사항
- Adaptive $\Delta t$: 학습 진행에 따라 $\Delta t$ 감소
- Improved noise schedule: EDM 스타일 noise schedule
- Better loss weighting: 시간에 따른 가중치 조정
def adaptive_delta_t(step, total_steps):
"""Delta t decreases during training."""
progress = step / total_steps
return delta_t_max * (1 - progress) + delta_t_min * progress9. 실험 결과
CIFAR-10 FID
ImageNet 64x64
핵심 발견
- 1-step CD가 기존 distillation 방법들보다 우수
- 2-step으로 품질 크게 향상
- CT는 CD보다 약간 낮지만, 교사 불필요
10. Latent Consistency Models (LCM)
Stable Diffusion에 적용
Consistency Models를 latent space에서 학습:
# Encode to latent
z = vae.encode(image)
# Train consistency model in latent space
z_0_pred = consistency_model(z_t, t)
# Decode for visualization
image_pred = vae.decode(z_0_pred)LCM의 성과
- 4 스텝으로 Stable Diffusion 품질 달성
- 기존 대비 5-10배 속도 향상
- CFG와 호환 가능
LCM-LoRA
LoRA로 효율적 학습:
# Base SD model + LCM-LoRA
pipe = StableDiffusionPipeline.from_pretrained("...")
pipe.load_lora_weights("lcm-lora-sdv1-5")
# Fast generation
image = pipe(prompt, num_inference_steps=4).images[0]결론
Consistency Models의 핵심:
- Self-consistency property를 활용
- ODE trajectory를 직접 학습하지 않고, endpoint를 예측
- 1-step 생성이 가능하면서도 multi-step으로 품질 향상 가능
References
- Song, Y., et al. "Consistency Models" (ICML 2023)
- Song, Y., Dhariwal, P. "Improved Techniques for Training Consistency Models" (2023)
- Luo, S., et al. "Latent Consistency Models" (2023)
- Karras, T., et al. "Elucidating the Design Space of Diffusion-Based Generative Models" (NeurIPS 2022)