Axel Sauer, Dominik Lorenz, Andreas Blattmann, Robin Rombach
- Stability AI
Official
https://stability.ai/news/stability-ai-sdxl-turbo
https://stability.ai/research/adversarial-diffusion-distillation
Arxiv
https://arxiv.org/abs/2311.17042
공식 Github
https://github.com/Stability-AI/generative-models
Huggingface
https://huggingface.co/stabilityai/sdxl-turbo
https://clipdrop.co/stable-diffusion-turbo
큰 규모로 데이터로 학습시킨 SDXL이라는 foundational image diffusion model에 있어 생성 속도는 항상 고민이었다.
여기서 LCM, LCM-LoRA과 같이 훨씬 적은 step으로도 원 SDXL에 비견할만한 품질의 inference를 사람들은 희망했다.
이에 Stability AI의 원 구성원들이 기대에 부응하듯 그 안을 제안했고, 상당한 진전을 보여 논문으로 소개하고자 한다.
훨씬 적은 step으로 이미지를 생성할 수 있는 방법론을 제안하는데 이는 자체적 개발보다 기존에 있던 2가지 기술이 크게 녹여들어있다.
1) Score Distillation (22.09 Dreamfusion에서 사용한 Score Distillation Sampling (SDS) 참조)
2) Adversarial Loss (GAN 참조)
자세한 학습 방법은 아래 Figure 2를 참조하길 바란다.
결과적으로 이를 통해 적은 step만으로도 좋은 성과를 달성할 수 있음을 보인다.
Diffusion Models이 잘 될 수 있었던 주요 요인 중 하나는 scalability와 반복적인 특성때문이지만, 반복하는 inference 방법은 상단한 sampling step을 요구한다.
이에 반해 GAN은 단일 step에서 생성할 수 있는 장점을 가지는데 이 속도를 본 방법에 적용할 수 있을지를 저자들은 고민했다.
결과적으로 저자들은 Adversarial Diffusion Distillation (ADD)라는 새로운 접근법을 제안했는데, 이는 2가지 학습 objective에 대한 조합으로 1~4 sampling step만으로 상당한 품질 결과를 도출할 수 있었다.
1) adversarial loss
2) distillation loss (22.09 Dreamfusion에서 사용한 Score Distillation Sampling (SDS) 참조)
Inference 동안 저자들은 CFG를 사용하지 않아 메모리 요구량을 감소시켰다.
구체적 방법은 3. Method에서 설명하도록 하겠다.
Distillation 기반 방법들 소개
progressive distillation, guidance distillation 연구분야 소개
Consistency Models
LCMs, LCM-LoRA, InstaFlow
그러나 이들은 blur하고 artifact 생김.
Score Distillation Sampling은 Score Jacobian Chaining이라 불리는데 T2I 모델의 지식을 3D 합성 모델에 distill하는데 사용함.
최근 Score-based model과 GAN 간의 강한 상관관계가 있음을 보이는데 여기서 영감을 받아 진행함.
본 방법에서는
1) pretrained diffusion model의 gradient를 score distillation objective에 사용. (규모있는 GAN에서의 text alignment 성능향상을 위해)
2) 초기 모델의 initialize를 pretrained diffusion model의 weight 사용.
3) GAN 학습에 있어 decoder만 사용하는 것 대신 표준 diffusion model framework 채택.
3개의 network가 필요
1) ADD student (trainable $\theta$ : pretrained UNet-DM weight로 초기화)
2) Discriminator (trainable $\phi$)
3) DM teacher (frozen $\psi$)
학습 동안의 forward diffusion process $x_s = \alpha_s x_0 + \sigma_s \epsilon$일 때,
실제 데이터 세트 이미지 $x_0$으로부터 noised data $x_s$가 생성되고 이를 ADD-studen를 통해 $\hat{x}_\theta (x_s , s)$ sample이 생성된다.
훈련동안, N=4이고 zero-terminal SNR 강제
Adversarial objective에서는 생성된 sample $\hat{x}_\theta$와 실제 이미지 $x_0$간의 discriminator 결과에 사용
Distillation을 위해서 Student sample $\hat{x}_\theta$를 teacher forward process에 넣은 다음, DM-teacher 모델의 결과 $\hat{x}_\psi (\hat{x}_{\theta, t},t)$를 도출한다. 여기서 distillation loss를 진행.
다만, latent space가 아닌 pixel space에서 distillation loss를 계산하는 것이 더 안정적인 gradient를 내보내기에 그렇게 사용함.
Discriminator에 있어 Stylegan-t에서 제안한 디자인과 훈련 절차를 따른다.
frozen pretrained feature network $\mathcal{F}$ : ViT
trainable lightweight discriminator heads $\mathcal{D}_{\phi, k}$ : 다른 layer $\mathcal{F}_k$의 feature에 적용
성능 향상을 위해 discriminator는 projection을 통해 추가 정보(text embedding, image) conditioning 진행할 수 있다. 특히, ADD-student에 입력 이미지를 효과적으로 사용되는 것을 받는 것을 위해 사용될 수 있기에 실제로 image embedding $c_{img}$ 추출을 위해 추가 feature network를 사용함.
hinge loss 사용 (https://en.wikipedia.org/wiki/Hinge_loss)
R1 : R1 gradient penalty - 128 px 이상에서 효과적임을 확인.
sg : stop-gradient operation
d : distance metric - ADD-student에 의해 생성된 sample $\hat{x}_\theta$와 DM-teacher 결과 $\hat{x}_\psi$간의 불일치 계산. distance function (L2 loss)
c(t) : weighting function
2가지 옵션 고려 : exponential weighting, score distillation sampling (SDS) weighting (선택)
장점 1) reconstruction target에 대한 직접적 시각화가 가능
장점 2) 연속적인 denoising step을 자연스럽게 실행 가능
SDS 변량 평가를 위해 noise-free score distillation (NFSD) objective 평가 진행함.
2 모델 훈련
ADD-M (860M parameters) - Stable Diffusion (SD) 2.1 backbone, ADD-XL (3.1B parameters)
실험 수행 해상도 : 512x512 pixels
distillation weighting factor λ = 2.5
R1 penalty strength : $10^{-5}$
FID, CLIP score로 정량 평가 진행.
Discriminator feature networks (Table 1a) : ViT에 DINOv2 조합이 좋았음
Discriminator conditioning (Table 1b) : $c_{text}$, $c_{img}$ 다 활용하는 것이 효과적
Student pretraining (Table 1c) : ADD-Student에 Pretrained generator 사용이 더 좋음
Loss terms (Table 1d) : adversarial loss와 결합이 성능향상
Teacher type (Table 1e) : 꼭 크다고 좋은 결과를 나타내지는 않음
Teacher steps (Table 1f) : 더 많은 step이 꼭 좋은 결과를 도출하지 않음
ELO Score (https://en.wikipedia.org/wiki/Elo_rating_system)
ADD가 다른 방법 대비 낫다는 평가를 받음.
다른 모델보다 더 나은 품질
1보다는 4step이 더 낫다.
비견할만함.
도움이 되는 YouTube 1.
0000