Rectified Flow: 1-Step 생성을 향한 경로 직선화

Rectified Flow: 1-Step 생성을 향한 경로 직선화
Flow Matching도 느리다면? Reflow로 경로를 펴서 1-step 생성까지 도달하는 방법.
TL;DR
- Rectified Flow: Flow Matching의 경로를 반복적으로 "직선화"하는 기법
- Reflow: 학습된 모델로 (noise, data) 쌍을 생성하고, 이 쌍으로 더 직선적인 경로를 재학습
- 핵심 이점: Reflow를 반복할수록 경로가 직선에 가까워지고, 최종적으로 1-step 생성 가능
- 실제 적용: Stable Diffusion 3, FLUX가 Rectified Flow 기반
1. 왜 Flow Matching만으로는 부족한가?
Flow Matching은 DDPM보다 훨씬 적은 스텝(10-50)으로 생성이 가능합니다. 하지만 여전히 한계가 있습니다.
Flow Matching의 한계
Flow Matching의 목표 속도장은:
$$v_t(x_t | x_0, z) = z - x_0$$
이론적으로는 상수 속도장이지만, 실제 학습에서는 marginal velocity field를 학습합니다:
$$v_t(x_t) = \mathbb{E}_{x_0, z | x_t}[z - x_0]$$
문제는 서로 다른 $(x_0, z)$ 쌍들이 같은 $x_t$를 지나갈 수 있다는 것입니다. 이 경로들이 교차(crossing)하면, 학습된 속도장은 이들의 평균이 되어 실제로는 곡선을 따라가게 됩니다.
경로 교차 문제
두 데이터 포인트 $x_0^{(1)}, x_0^{(2)}$와 두 노이즈 $z^{(1)}, z^{(2)}$가 있을 때:
$$x_t^{(1)} = (1-t)x_0^{(1)} + tz^{(1)}$$
$$x_t^{(2)} = (1-t)x_0^{(2)} + tz^{(2)}$$
어떤 $t$에서 $x_t^{(1)} = x_t^{(2)}$가 되면, 신경망은 두 방향의 평균을 예측하게 됩니다. 이것이 transport cost를 증가시키고 샘플링 스텝을 늘려야 하는 원인입니다.
2. Rectified Flow의 핵심 아이디어
Rectified Flow는 간단하지만 강력한 아이디어입니다:
**"학습된 flow로 (z, x₀) 쌍을 만들고, 이 쌍으로 다시 직선 경로를 학습하면 경로가 펴진다"**
Reflow 절차
- 초기 Flow Matching 학습: 랜덤 $(x_0, z)$ 쌍으로 기본 모델 $v_{\theta_0}$ 학습
- Coupling 생성: 학습된 모델로 noise $z$에서 시작해 data $\hat{x}_0$를 생성
- 이제 $(z, \hat{x}_0)$는 실제로 flow를 따라 연결된 쌍
- Reflow 학습: 새로운 모델 $v_{\theta_1}$을 $(z, \hat{x}_0)$ 쌍의 직선 경로로 학습
- 반복: 2-3을 반복할수록 경로가 더 직선화
수학적 표현
$k$번째 reflow 후의 coupling을 $\pi_k$라 하면:
$$\mathcal{L}_{\text{reflow}}^{(k)} = \mathbb{E}_{(x_0, z) \sim \pi_k, t} \left[ \| (z - x_0) - v_{\theta}(x_t, t) \|^2 \right]$$
여기서 $x_t = (1-t)x_0 + tz$이고, $\pi_k$는 $k$번째 모델이 생성한 coupling입니다.
3. 왜 Reflow가 경로를 직선화하는가?
직관적 이해
처음에는 랜덤 coupling $(x_0, z)$를 사용합니다. 이 경로들은 서로 교차할 수 있습니다.
하지만 학습된 flow $\phi_1$을 따라가면:
- $z$에서 출발한 경로는 특정 $\hat{x}_0$에 도착
- 이 $(z, \hat{x}_0)$ 쌍은 이미 flow를 따라 연결되어 있음
- 따라서 이 쌍들의 직선 경로는 덜 교차함
Transport Cost 감소
Reflow의 핵심은 transport cost를 줄이는 것입니다:
$$\text{Cost}(\pi) = \mathbb{E}_{(x_0, z) \sim \pi} \left[ \| z - x_0 \|^2 \right]$$
Reflow를 반복하면:
$$\text{Cost}(\pi_0) \geq \text{Cost}(\pi_1) \geq \text{Cost}(\pi_2) \geq \cdots$$
경로가 직선화되면서 transport cost가 감소합니다.
이론적 보장
논문에서 증명된 중요한 성질:
- Causality: Reflow된 coupling은 인과적(causal)입니다. 즉, $z$가 주어지면 $x_0$가 결정됨
- Straightness: Reflow를 무한히 반복하면 경로가 완전히 직선이 됨
- 1-step 가능성: 완전히 직선화되면 Euler 1-step으로 정확한 샘플링 가능
4. 1-Step Distillation
Reflow만으로도 경로가 직선화되지만, 실용적으로 1-step 생성을 위해서는 distillation이 필요합니다.
Progressive Distillation
스텝 수를 점진적으로 줄이는 방법:
- Teacher 모델: N steps
- Student 모델: N/2 steps로 teacher 출력 모방
- 반복하여 1-step까지 도달
$$\mathcal{L}_{\text{distill}} = \mathbb{E}_{z} \left[ \| \phi_{\text{teacher}}(z) - G_{\theta}(z) \|^2 \right]$$
Direct Distillation
Rectified Flow의 장점은 경로가 이미 직선에 가깝기 때문에 직접 1-step distillation이 가능하다는 것:
$$\mathcal{L}_{\text{1-step}} = \mathbb{E}_{z} \left[ \| x_0 - (z - v_{\theta}(z, 1)) \|^2 \right]$$
여기서 $v_{\theta}(z, 1)$은 $t=1$에서의 속도 예측입니다.
5. 구현
Reflow 학습
class RectifiedFlow:
def __init__(self, model):
self.model = model
def loss(self, x0, z):
"""Reflow loss with fixed coupling."""
t = torch.rand(x0.shape[0], device=x0.device)
# Linear interpolation
x_t = (1 - t[:, None]) * x0 + t[:, None] * z
# Target velocity (straight line)
v_target = z - x0
# Predicted velocity
v_pred = self.model(x_t, t)
return F.mse_loss(v_pred, v_target)
@torch.no_grad()
def sample(self, z, n_steps=1):
"""Sample with Euler method."""
x = z
dt = 1.0 / n_steps
for i in range(n_steps):
t = 1.0 - i * dt
t_batch = torch.full((x.shape[0],), t, device=x.device)
v = self.model(x, t_batch)
x = x - v * dt
return x
@torch.no_grad()
def generate_coupling(self, z, n_steps=50):
"""Generate (z, x0) coupling pairs."""
x0 = self.sample(z, n_steps=n_steps)
return z, x0Reflow 학습 루프
def train_reflow(data, n_reflows=3, n_epochs=500):
"""Train with multiple reflow iterations."""
# Initial Flow Matching
model = create_model()
rf = RectifiedFlow(model)
# Train on random coupling
for epoch in range(n_epochs):
x0 = sample_data(data)
z = torch.randn_like(x0)
loss = rf.loss(x0, z)
loss.backward()
optimizer.step()
# Reflow iterations
for k in range(n_reflows):
print(f"Reflow {k+1}")
# Generate coupling from current model
z_all = torch.randn(len(data), dim)
z_all, x0_all = rf.generate_coupling(z_all, n_steps=50)
# Train new model on this coupling
new_model = create_model()
new_rf = RectifiedFlow(new_model)
for epoch in range(n_epochs):
idx = torch.randperm(len(x0_all))[:batch_size]
loss = new_rf.loss(x0_all[idx], z_all[idx])
loss.backward()
optimizer.step()
rf = new_rf
return rf1-Step Distillation
def distill_to_one_step(teacher_rf, student_model, data, n_epochs=1000):
"""Distill to 1-step generator."""
optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-4)
for epoch in range(n_epochs):
z = torch.randn(batch_size, dim)
# Teacher generates target
with torch.no_grad():
x0_teacher = teacher_rf.sample(z, n_steps=10)
# Student predicts in 1 step
# x0 = z - v(z, t=1)
v_pred = student_model(z, torch.ones(batch_size))
x0_student = z - v_pred
loss = F.mse_loss(x0_student, x0_teacher)
optimizer.zero_grad()
loss.backward()
optimizer.step()
return student_model6. Stable Diffusion 3와 FLUX
SD3의 Rectified Flow 적용
Stable Diffusion 3는 Rectified Flow를 채택했습니다:
- MMDiT 아키텍처: 텍스트와 이미지를 동시에 처리하는 Multimodal DiT
- Rectified Flow: 기존 DDPM 대신 직선 경로 학습
- 결과: 동일 품질에서 더 적은 스텝 필요
FLUX의 발전
FLUX (by Black Forest Labs)는 SD3를 더 발전시켰습니다:
- Guidance Distillation: CFG를 모델에 내재화
- 더 적은 스텝: 4-8 스텝으로 고품질 생성
- FLUX.1-schnell: 1-4 스텝 생성 가능한 distilled 버전
왜 Rectified Flow인가?
기존 Stable Diffusion (DDPM 기반)에서 전환한 이유:
7. Reflow 횟수와 품질
몇 번의 Reflow가 필요한가?
실험적으로:
- 1-Reflow: 상당한 직선화, 10-step으로 좋은 품질
- 2-Reflow: 더 직선화, 5-step 가능
- 3-Reflow: 거의 직선, 1-2 step 가능
하지만 reflow를 많이 할수록:
- 학습 시간 증가
- Coupling 생성에 시간 소요
- 수렴 속도 감소 가능
실용적 선택
대부분의 경우 1-2회 reflow + distillation이 가장 효율적입니다.
8. 한계와 주의사항
Coupling 품질 의존성
Reflow는 이전 모델의 생성 품질에 의존합니다:
- 초기 모델이 나쁘면 → 나쁜 coupling → 나쁜 reflow 결과
- 해결책: 초기 Flow Matching을 충분히 학습
Mode Collapse 위험
Reflow를 너무 많이 하면:
- Coupling이 특정 모드에 집중될 수 있음
- 다양성(diversity) 감소 가능
- 해결책: 적절한 reflow 횟수 선택, regularization
계산 비용
각 reflow 단계마다:
- 전체 데이터셋에 대해 coupling 생성 필요
- 새 모델 학습 필요
- 총 비용 = (1 + n_reflows) × 기본 학습 비용
결론
Rectified Flow는 "경로를 펴면 빨라진다"는 직관적 아이디어를 실현한 방법입니다. Stable Diffusion 3와 FLUX의 성공이 이 접근법의 실용성을 증명했습니다.
References
- Liu, X., et al. "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow" (ICLR 2023)
- Esser, P., et al. "Scaling Rectified Flow Transformers for High-Resolution Image Synthesis" (Stable Diffusion 3, 2024)
- Lipman, Y., et al. "Flow Matching for Generative Modeling" (ICLR 2023)
- Salimans, T. & Ho, J. "Progressive Distillation for Fast Sampling of Diffusion Models" (ICLR 2022)