아이공의 AI 공부 도전기

[논문 Summary] SDXL-Turbo (ADD) (2023.11 Arxiv) "Adversarial Diffusion Distillation"

 

     

 

논문 정보

Citation : 2024.02.03 토요일 기준 3회

저자

Axel Sauer, Dominik Lorenz, Andreas Blattmann, Robin Rombach

 

- Stability AI

논문 & Github 링크

Official

 

 

https://stability.ai/news/stability-ai-sdxl-turbo

 

Introducing SDXL Turbo: A Real-Time Text-to-Image Generation Model — Stability AI

SDXL Turbo is a new text-to-image mode based on a novel distillation technique called Adversarial Diffusion Distillation (ADD), enabling the model to create image outputs in a single step and generate real-time text-to-image outputs while maintaining high

stability.ai

 

https://stability.ai/research/adversarial-diffusion-distillation

 

 

Arxiv

https://arxiv.org/abs/2311.17042

 

 

공식 Github

https://github.com/Stability-AI/generative-models

 

 

Huggingface

https://huggingface.co/stabilityai/sdxl-turbo

 

https://clipdrop.co/stable-diffusion-turbo

 

 

논문 Summary

 

0. 설명 시작 전 Overview

 

 

큰 규모로 데이터로 학습시킨 SDXL이라는 foundational image diffusion model에 있어 생성 속도는 항상 고민이었다.

여기서 LCM, LCM-LoRA과 같이 훨씬 적은 step으로도 원 SDXL에 비견할만한 품질의 inference를 사람들은 희망했다.

 

이에 Stability AI의 원 구성원들이 기대에 부응하듯 그 안을 제안했고, 상당한 진전을 보여 논문으로 소개하고자 한다.

 

훨씬 적은 step으로 이미지를 생성할 수 있는 방법론을 제안하는데 이는 자체적 개발보다 기존에 있던 2가지 기술이 크게 녹여들어있다.

 

1) Score Distillation (22.09 Dreamfusion에서 사용한 Score Distillation Sampling (SDS) 참조)

2) Adversarial Loss (GAN 참조)

 

자세한 학습 방법은 아래 Figure 2를 참조하길 바란다.

 

결과적으로 이를 통해 적은 step만으로도 좋은 성과를 달성할 수 있음을 보인다.

 

1. Introduction

 

Diffusion Models이 잘 될 수 있었던 주요 요인 중 하나는 scalability와 반복적인 특성때문이지만, 반복하는 inference 방법은 상단한 sampling step을 요구한다.

이에 반해 GAN은 단일 step에서 생성할 수 있는 장점을 가지는데 이 속도를 본 방법에 적용할 수 있을지를 저자들은 고민했다.

 

결과적으로 저자들은 Adversarial Diffusion Distillation (ADD)라는 새로운 접근법을 제안했는데, 이는 2가지 학습 objective에 대한 조합으로 1~4 sampling step만으로 상당한 품질 결과를 도출할 수 있었다.

 

1) adversarial loss

2) distillation loss (22.09 Dreamfusion에서 사용한 Score Distillation Sampling (SDS) 참조)

 

Inference 동안 저자들은 CFG를 사용하지 않아 메모리 요구량을 감소시켰다.

 

구체적 방법은 3. Method에서 설명하도록 하겠다.

 

2. Background

Distillation 기반 방법들 소개

progressive distillation, guidance distillation 연구분야 소개

Consistency Models

LCMs, LCM-LoRA, InstaFlow

그러나 이들은 blur하고 artifact 생김.

 

Score Distillation Sampling은 Score Jacobian Chaining이라 불리는데 T2I 모델의 지식을 3D 합성 모델에 distill하는데 사용함.

 

최근 Score-based model과 GAN 간의 강한 상관관계가 있음을 보이는데 여기서 영감을 받아 진행함.

 

3. Methods

 

본 방법에서는 

1) pretrained diffusion model의 gradient를 score distillation objective에 사용. (규모있는 GAN에서의 text alignment 성능향상을 위해)

2) 초기 모델의 initialize를 pretrained diffusion model의 weight 사용.

3) GAN 학습에 있어 decoder만 사용하는 것 대신 표준 diffusion model framework 채택.

 

 

3.1 Training Procedure

 

 

3개의 network가 필요

 

1) ADD student (trainable $\theta$ : pretrained UNet-DM weight로 초기화)

2) Discriminator (trainable $\phi$)

3) DM teacher (frozen $\psi$)

 

학습 동안의 forward diffusion process $x_s = \alpha_s x_0 + \sigma_s \epsilon$일 때,

실제 데이터 세트 이미지 $x_0$으로부터 noised data $x_s$가 생성되고 이를 ADD-studen를 통해 $\hat{x}_\theta (x_s , s)$ sample이 생성된다.

훈련동안, N=4이고 zero-terminal SNR 강제

 

Adversarial objective에서는 생성된 sample $\hat{x}_\theta$와 실제 이미지 $x_0$간의 discriminator 결과에 사용

 

Distillation을 위해서 Student sample $\hat{x}_\theta$를 teacher forward process에 넣은 다음, DM-teacher 모델의 결과 $\hat{x}_\psi (\hat{x}_{\theta, t},t)$를 도출한다. 여기서 distillation loss를 진행.

 

 

다만, latent space가 아닌 pixel space에서 distillation loss를 계산하는 것이 더 안정적인 gradient를 내보내기에 그렇게 사용함. 

 

3.2 Adversarial Loss

 

Discriminator에 있어 Stylegan-t에서 제안한 디자인과 훈련 절차를 따른다.

 

frozen pretrained feature network $\mathcal{F}$ : ViT

trainable lightweight discriminator heads $\mathcal{D}_{\phi, k}$ : 다른 layer $\mathcal{F}_k$의 feature에 적용

 

성능 향상을 위해 discriminator는 projection을 통해 추가 정보(text embedding, image) conditioning 진행할 수 있다. 특히, ADD-student에 입력 이미지를 효과적으로 사용되는 것을 받는 것을 위해 사용될 수 있기에 실제로 image embedding $c_{img}$ 추출을 위해 추가 feature network를 사용함.

 

hinge loss 사용 (https://en.wikipedia.org/wiki/Hinge_loss)

 

 

 

R1 : R1 gradient penalty - 128 px 이상에서 효과적임을 확인.

 

3.3 Score Distillation Loss

 

sg : stop-gradient operation

d : distance metric - ADD-student에 의해 생성된 sample $\hat{x}_\theta$와 DM-teacher 결과 $\hat{x}_\psi$간의 불일치 계산. distance function (L2 loss)

 

c(t) : weighting function

2가지 옵션 고려 : exponential weighting, score distillation sampling (SDS) weighting (선택)

 

장점 1) reconstruction target에 대한 직접적 시각화가 가능

장점 2) 연속적인 denoising step을 자연스럽게 실행 가능

 

SDS 변량 평가를 위해 noise-free score distillation (NFSD) objective 평가 진행함.

 

4. Experiments

2 모델 훈련

ADD-M (860M parameters) - Stable Diffusion (SD) 2.1 backbone, ADD-XL (3.1B parameters)

 

실험 수행 해상도 : 512x512 pixels

distillation weighting factor λ = 2.5

R1 penalty strength : $10^{-5}$

 

FID, CLIP score로 정량 평가 진행.

 

4.1 Ablation Study

 

 

Discriminator feature networks (Table 1a) : ViT에 DINOv2 조합이 좋았음

 

Discriminator conditioning (Table 1b) :  $c_{text}$, $c_{img}$ 다 활용하는 것이 효과적

 

Student pretraining (Table 1c) : ADD-Student에 Pretrained generator 사용이 더 좋음

 

Loss terms (Table 1d) : adversarial loss와 결합이 성능향상

 

Teacher type (Table 1e) : 꼭 크다고 좋은 결과를 나타내지는 않음

 

Teacher steps (Table 1f) : 더 많은 step이 꼭 좋은 결과를 도출하지 않음

 

 

4.2 Quantitative Comparison of State of the Art

 

 

ELO Score (https://en.wikipedia.org/wiki/Elo_rating_system)

ADD가 다른 방법 대비 낫다는 평가를 받음.

 

4.3 Qualitative Results

 

다른 모델보다 더 나은 품질

 

 

1보다는 4step이 더 낫다.

 

 

비견할만함.

 

Reference

 

도움이 되는 YouTube 1.

 

 

 

 

 

 

0000 

공유하기

facebook twitter kakaoTalk kakaostory naver band
loading