Stable Diffusion 3 & FLUX: MMDiT 아키텍처 완전 분석

Stable Diffusion 3 & FLUX: MMDiT 아키텍처 완전 분석
U-Net을 버리고 Transformer로. Text와 Image를 동등하게 처리하는 새로운 패러다임.
TL;DR
- MMDiT (Multimodal DiT): 텍스트와 이미지를 하나의 Transformer에서 동시 처리
- Rectified Flow 채택: DDPM 대신 직선 경로로 빠른 생성
- FLUX 발전: Guidance Distillation으로 CFG 없이 4-8 스텝 생성
- 핵심 혁신: Text-Image 간 양방향 attention으로 더 정확한 프롬프트 따르기
1. 왜 U-Net을 버렸는가?
U-Net의 한계
Stable Diffusion 1.x/2.x는 U-Net 기반이었습니다:
Text Encoder (CLIP) → Cross-Attention → U-Net → Image문제점:
- 일방향 정보 흐름: 텍스트 → 이미지만 가능, 이미지 → 텍스트 피드백 없음
- Cross-attention 병목: 텍스트 정보가 특정 레이어에서만 주입
- Scaling 한계: U-Net은 모델 크기 증가 시 성능 향상이 수확체감
DiT의 등장
DiT (Diffusion Transformer)가 보여준 것:
- Transformer는 scaling law를 따름
- 모델이 클수록 일관되게 FID 개선
- 하지만 DiT도 텍스트 처리는 cross-attention 방식
MMDiT: 진정한 Multimodal
SD3의 MMDiT는 텍스트와 이미지를 동등한 시퀀스로 처리:
[Text Tokens] + [Image Tokens] → Joint Transformer → [Text'] + [Image']양방향 attention으로 텍스트가 이미지를 보고, 이미지도 텍스트를 봅니다.
2. MMDiT 아키텍처 상세
입력 처리
이미지 입력:
- VAE Encoder로 latent 추출:
(H, W, 3)→(h, w, 16) - Patchify:
(h, w, 16)→(N_img, D) - Position embedding 추가
텍스트 입력:
- 세 가지 텍스트 인코더 사용:
- CLIP-L (OpenAI)
- CLIP-G (OpenCLIP)
- T5-XXL (Google)
- Pooled + Sequence embeddings 결합
(N_txt, D)형태로 변환
Joint Attention Block
핵심은 MM-DiT Block입니다:
class MMDiTBlock(nn.Module):
def __init__(self, dim):
self.norm1_img = AdaLayerNorm(dim)
self.norm1_txt = AdaLayerNorm(dim)
self.attn = JointAttention(dim)
self.norm2_img = AdaLayerNorm(dim)
self.norm2_txt = AdaLayerNorm(dim)
self.ff_img = FeedForward(dim)
self.ff_txt = FeedForward(dim)
def forward(self, img, txt, timestep):
# Separate normalization
img_norm = self.norm1_img(img, timestep)
txt_norm = self.norm1_txt(txt, timestep)
# Joint attention (핵심!)
img_attn, txt_attn = self.attn(img_norm, txt_norm)
img = img + img_attn
txt = txt + txt_attn
# Separate feedforward
img = img + self.ff_img(self.norm2_img(img, timestep))
txt = txt + self.ff_txt(self.norm2_txt(txt, timestep))
return img, txtJoint Attention 메커니즘
class JointAttention(nn.Module):
def forward(self, img, txt):
# 이미지와 텍스트를 concat
x = torch.cat([img, txt], dim=1)
# Q, K, V 계산
q = self.to_q(x)
k = self.to_k(x)
v = self.to_v(x)
# Self-attention (모든 토큰이 서로를 봄)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
out = attn @ v
# 다시 분리
img_out, txt_out = out.split([img.shape[1], txt.shape[1]], dim=1)
return img_out, txt_out핵심 포인트:
- 이미지 토큰이 텍스트 토큰에 attend
- 텍스트 토큰이 이미지 토큰에 attend
- 양방향 정보 흐름으로 더 정확한 text-image alignment
AdaLN (Adaptive Layer Normalization)
timestep 정보를 주입하는 방법:
class AdaLayerNorm(nn.Module):
def __init__(self, dim):
self.norm = nn.LayerNorm(dim, elementwise_affine=False)
self.proj = nn.Linear(dim, dim * 2)
def forward(self, x, timestep_emb):
# timestep에서 scale, shift 예측
scale, shift = self.proj(timestep_emb).chunk(2, dim=-1)
# Adaptive normalization
x = self.norm(x)
x = x * (1 + scale) + shift
return x3. Rectified Flow in SD3
SD3는 DDPM 대신 Rectified Flow를 사용합니다.
왜 Rectified Flow인가?
SD3의 Flow Formulation
def flow_matching_loss(model, x0, text_emb):
# Sample time
t = torch.rand(x0.shape[0])
# Sample noise
z = torch.randn_like(x0)
# Linear interpolation
x_t = (1 - t) * x0 + t * z
# Target velocity
v_target = z - x0
# Predict velocity
v_pred = model(x_t, t, text_emb)
return F.mse_loss(v_pred, v_target)Logit-Normal Sampling
SD3의 특별한 점: timestep을 uniform이 아닌 logit-normal 분포에서 샘플링
def logit_normal_sample(batch_size, m=0.0, s=1.0):
"""중간 timestep에 더 집중"""
u = torch.randn(batch_size) * s + m
t = torch.sigmoid(u) # (0, 1) 범위로 변환
return t이유: 중간 timestep이 학습에 더 중요하기 때문
4. FLUX: SD3의 진화
FLUX는 Black Forest Labs (SD3 개발자들이 설립)에서 만든 모델입니다.
FLUX vs SD3 비교
FLUX 변형들
- FLUX.1-pro: 최고 품질, API only
- FLUX.1-dev: 연구/개발용, 오픈 웨이트
- FLUX.1-schnell: 1-4 스텝 생성, 가장 빠름
Guidance Distillation
FLUX.1-schnell의 핵심 기술:
기존 CFG (Classifier-Free Guidance):
# 추론 시 2배 계산 필요
pred_uncond = model(x_t, t, null_text)
pred_cond = model(x_t, t, text)
pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)Guidance Distillation 후:
# 1번의 forward pass로 CFG 효과
pred = model(x_t, t, text) # guidance가 내재화됨학습 방법:
def guidance_distillation_loss(student, teacher, x_t, t, text):
# Teacher: CFG 적용
with torch.no_grad():
pred_uncond = teacher(x_t, t, null_text)
pred_cond = teacher(x_t, t, text)
target = pred_uncond + cfg_scale * (pred_cond - pred_uncond)
# Student: 단일 forward
pred = student(x_t, t, text)
return F.mse_loss(pred, target)5. 텍스트 인코더 전략
SD3의 Triple Text Encoder
SD3는 세 가지 텍스트 인코더를 사용:
class TripleTextEncoder:
def __init__(self):
self.clip_l = CLIPTextModel("openai/clip-vit-large")
self.clip_g = CLIPTextModel("laion/CLIP-ViT-bigG")
self.t5 = T5EncoderModel("google/t5-v1_1-xxl")
def encode(self, text):
# CLIP embeddings (pooled)
clip_l_pooled, clip_l_seq = self.clip_l(text)
clip_g_pooled, clip_g_seq = self.clip_g(text)
# T5 embedding (sequence only)
t5_seq = self.t5(text)
# Pooled: conditioning용
pooled = torch.cat([clip_l_pooled, clip_g_pooled], dim=-1)
# Sequence: cross-attention용
seq = torch.cat([clip_l_seq, clip_g_seq, t5_seq], dim=1)
return pooled, seq왜 세 개인가?
T5의 추가로 긴 프롬프트와 복잡한 관계 이해가 크게 향상되었습니다.
FLUX의 텍스트 처리
FLUX는 더 단순화:
- CLIP-L + T5-XXL 조합
- T5에 더 의존 (더 긴 컨텍스트 활용)
6. VAE 개선
SD3의 16채널 VAE
기존 SD 1.x/2.x: 4채널 latent
SD3/FLUX: 16채널 latent
# SD 1.x/2.x
latent = vae.encode(image) # (B, 4, H/8, W/8)
# SD3/FLUX
latent = vae.encode(image) # (B, 16, H/8, W/8)장점:
- 더 많은 정보 보존
- 세밀한 디테일 재구성
- 텍스트 렌더링 품질 향상
단점:
- 메모리 사용량 증가
- 계산량 증가
7. 실제 사용 예시
Diffusers로 SD3 사용
from diffusers import StableDiffusion3Pipeline
import torch
pipe = StableDiffusion3Pipeline.from_pretrained(
"stabilityai/stable-diffusion-3-medium",
torch_dtype=torch.float16
)
pipe = pipe.to("cuda")
image = pipe(
prompt="A cat holding a sign that says 'Hello World'",
num_inference_steps=28,
guidance_scale=7.0,
).images[0]Diffusers로 FLUX 사용
from diffusers import FluxPipeline
import torch
# FLUX.1-schnell (빠른 버전)
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")
image = pipe(
prompt="A cat holding a sign that says 'Hello World'",
num_inference_steps=4, # 4 스텝만!
guidance_scale=0.0, # CFG 불필요
).images[0]메모리 최적화
# CPU offload
pipe.enable_model_cpu_offload()
# Attention slicing
pipe.enable_attention_slicing()
# VAE tiling (고해상도용)
pipe.enable_vae_tiling()8. 성능 비교
텍스트 렌더링 능력
SD3/FLUX의 가장 큰 개선점: 텍스트를 이미지에 정확히 렌더링
프롬프트 따르기
복잡한 프롬프트 테스트: "A red cube on top of a blue sphere, with a green pyramid to the left"
생성 속도 (A100 기준)
9. 한계와 주의사항
메모리 요구사항
라이선스
- SD3: Stability AI Community License (상업적 제한)
- FLUX.1-dev: 연구/개발용 (상업적 제한)
- FLUX.1-schnell: Apache 2.0 (상업적 사용 가능)
알려진 문제
- 인체 해부학: 여전히 손가락 등에서 오류 발생
- 텍스트 일관성: 긴 텍스트에서 가끔 오류
- 스타일 다양성: 특정 스타일에 편향될 수 있음
결론
SD3와 FLUX는 U-Net에서 Transformer로, DDPM에서 Rectified Flow로의 패러다임 전환을 보여줍니다. MMDiT의 양방향 attention은 text-image alignment를 크게 개선했고, Rectified Flow는 빠른 생성을 가능하게 했습니다.
References
- Esser, P., et al. "Scaling Rectified Flow Transformers for High-Resolution Image Synthesis" (SD3 Paper, 2024)
- Black Forest Labs. "FLUX.1 Technical Report" (2024)
- Peebles, W. & Xie, S. "Scalable Diffusion Models with Transformers" (DiT, 2023)
- Liu, X., et al. "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow" (2023)