SparseGPT: Massive Language Models Can be Accurately Pruned in One-Shot

 본 문서에서는 2023년 ICML에서 발표된  SparseGPT논문을 소개합니다. 논문에 대한 상세한 정보는 다음과 같습니다. 

  • Title: SparseGPT: Massive Language Models Can be Accurately Pruned in One-Shot
  • Authors: Elias Frantar and Dan Alistarh
  • Conference: International Conference on Machine Learning (ICML) 2023

Pruning of Pretrained LLMs
ChatGPT를 필두로 하는 초거대 언어 모델(Large Language Model, LLM)들은 현재 이전에는 불가능하다고 여겨졌던 많은 일들을 가능케하고 있습니다. 대표적인 예로써, 문맥을 이해하고 대화가 가능한 챗봇(chatbot) 그리고 자연어로 주어진 문장을 바탕으로 소스코드를 생성하는 코파일럿(copilot)등이 있습니다. 이러한 초거대 언어모델들의 높은 성능은 매우 많은 파라미터(parameter)들의 수를 기반으로 거대한 규모의 언어 말뭉치(corpus)를 학습함으로써 얻어지는데, 많은 수의 파라미터들로 인해 추론을 하는데 필요한 계산의 양이 많고, 시간이 오래걸리며, 매우 높은 성능의 하드웨어들이 여러 대 필요하다는 부작용들이 있습니다. 기학습 초거대 언어 모델(Pretrained LLM) 가지치기(pruning) 문제는 이미 학습되어 있는 초거대 언어모델의 우수한 성능은 유지하면서 불필요한 파라미터들을 가지치기함으로써 초거대 언어 모델들의 우수한 성능을 부작용 없이 누리는 것을 목표로 합니다. 추가적으로 본 논문에서는 압축 알고리즘의 효율성을 극대화하는 것을 목표로 하여 1대의 A100 GPU만을 사용해서 초거대 언어 모델을 압축하고자 합니다.

본 논문에서 해결하고자 하는 기학습 초거대 언어 모델 가지치기 문제의 구체적인 정의는 다음과 같습니다.
  • 주어진 정보
    • 학습되어 있는 고성능의 초거대 언어 모델
    • 적은 숫자의 샘플 데이터 셋
    • 원하는 가지치기 된 모델의 크기
  • 목표
    •  주어진 정보들을 활용하여 원하는 크기에 맞고 정확한 언어 모델을 생성하는 것
  • 제한 조건
    • 1대의 A100 GPU 상에서 가지치기 알고리즘이 수행가능하여야 함
A100 GPU 1대만을 이용하여 기학습된 초거대 언어 모델을 가지치기하는 알고리즘을 설계하기 위해서는 그 거대한 크기로 인해서 아래와 같은 문제점들을 해결하여야 합니다.
  • (문제 구성) GPU 1대만을 이용해서 추론을 진행하게 되면 추론 과정이 매우 비효율적이게 되며, 특히 기존의 기학습 언어 모델 압축 문제에서 사용되던 크로스 엔트로피(cross-entropy)를 측정하는 것 또한 매우 비효율적입니다. 어떻게 문제를 구성하면 GPU 1대에서도 효율적으로 문제를 풀어나갈 수 있을까요?
  • (마스크 선택) 가지치기 기법에서는 모델 내의 불필요한 파라미터들을 찾아내고 마스킹하는 마스크 선택을 정확하게 하는 것이 중요합니다. 어떻게 하면 효율적으로 좋은 마스크를 찾을 수 있을까요?
  • (가중치 복구) 마스킹한 파라미터들을 삭제하고 나면 모델의 추론 결과가 변하게 되며, 원본 모델의 우수한 성능을 유지하기 위해서는 살아남은 파라미터들을 업데이트하여 원본 모델의 추론 결과를 복구해야 합니다. 어떻게 하면 효율적으로 가중치들을 업데이트하여 가지치기로 인한 성능 하락을 최소화할 수 있을까요?
Proposed Method
본 논문에서는 앞서 설명한 문제점들을 해결하기 위한 기법인 SparseGPT를 제안합니다. 저자들은 아래와 같은 핵심 아이디어들을 활용하여 SpraseGPT를 설계하였습니다.

  • (문제 구성) 모델의 추론을 끝까지 하지 않고, 추론의 중간 과정의 각 서브레이어(sublayer)에 있는 가중치 행렬들에 대해 순차적으로 가지치기를 수행합니다. 이때 추론을 끝까지 진행하지 않기 때문에, 각 가중치 행렬별로 출력 값을 복원하는 것을 목적함수로 사용합니다.
  • (마스크 선택) 각 가중치 행렬들에 있는 일정 개수의 열(column)들에 대해 마스크 선택을 진행합니다. SparseGPT의 마스크 선택 과정은 하나의 가중치 행렬 내에서 마스크를 찾기 때문에 빠르고, 가지치기와 마스크 선택을 번갈아가며 반복해서 진행하기 때문에 정확합니다.
  • (가중치 복구) OBS[1]라는 기존 알고리즘을 활용하면 하나의 가중치가 사라졌을 때 출력 값을 복원하기 위한 최적의 가중치 행렬의 변화량을 구할 수 있습니다. 하지만, OBS에서는 헤시안(Hessian)의 역수를 반복해서 구하는 것이 매우 시간이 오래걸리는데, 저자들은 이를 극복하기 위한 효율적인 방법들을 제안합니다.
지금부터는 각 아이디어들에 대해 핵심적인 내용들을 위주로 설명하도록 하겠습니다.

Problem Formulation
SparseGPT의 저자들은 모델을 끝까지 추론하지 않고, 일부분만을 추론하면서도 가지치기를 가능하게 하기 위해서, 각 서브레이어에 있는 가중치 행렬들을 순차적으로 가지치기 하도록 문제를 구성하였습니다. 이렇게 문제를 구성할 경우 해당 순서의 가중치 행렬만 GPU에 업로드한 후에 가지치기가 끝나면 내리고, 다음 가중치 행렬을 올리는 식으로 진행할 수 있기 때문에 GPU를 한 대만 사용하더라도 효율적으로 가지치기를 수행할 수 있습니다. 뿐만 아니라, 앞선 서브레이어에 있는 가중치 행렬들 대해 가지치기를 수행한 후에 다음 서브레이어에 있는 가중치 행렬들에 대해 가지치기를 수행하기 때문에 앞선 가지치기 결과를 반영한 정확한 가지치기를 수행할 수 있습니다.
이때, 모델의 추론을 끝까지 하지 않기 때문에, 크로스 엔트로피(cross-entropy)는 계산할 수 없으며, 아래 수식과 같이 각 가중치별로 출력 결과를 복원하는 목적함수를 최소화하도록 문제를 구성합니다. 이때, $\odot$는 원소별 곱(element-wise multiplication)을 의미합니다.
$$\arg\min_{M,\widehat{W}} || WX-(M\odot \widehat{W})X)||_F^2$$
이때 수식에서 $W$는 원본 모델의 가중치 행렬, $X$는 입력 행렬, $M$은 가지치기를 위한 마스크 행렬, 그리고 $\widehat{W}$는 가지치기된 모델의 가중치 행렬을 의미합니다. 이때 해당 문제에서 최적의 가지치기 마스크와 가중치 행렬를 동시에 찾는 것은 어렵기 때문에, 각 단계를 나누어 수행하게 되며, 이어지는 내용에서 각각의 과정을 설명하도록 하겠습니다.

Mask Selection
 마스크 선택(mask selection) 과정은 가지치기를 수행하기 위한 불필요한 요소들을 찾아내고, 마스킹을 하는 과정을 의미합니다. 앞서 저자들이 문제를 각 가중치 행렬에 대해 출력을 복원하는 문제로 설정하였기 때문에, 마스크 선택 또한 각 가중치 행렬들 내에서 이루어지게 됩니다. 이때, 하나의 가중치 행렬에 대해 한 번에 마스크 선택을 수행하게 되면 넓은 범위를 탐색하기 위해 많은 비용이 들고, 너무 적은 개수의 가중치들에 대해 마스크 탐색을 진행하게 되면 그 정확도가 낮아지게 됩니다. 그러므로 저자들은 일정 개수의 열(column)들에 있는 가중치들을 하나의 그룹으로 편성하여 이들에 대해 마스크 선택을 진행합니다. 저자들은 각 열들의 집합에 속한 가중치들에 대해 OBS[1] 알고리즘을 수행함으로써 계산되는 각 가중치별 가지치기 후 추정 오차를 이용하여 추정 오차가 가장 낮은 가중치들을 원하는 압축률에 맞게 선택합니다.
그림 1. SparseGPT의 마스크 선택 예시
그림 1은 SparseGPT가 마스크를 선택하는 과정의 예시를 나타내고 있습니다. 예시에서는 2개의 열을 기준으로 마스크 선택을 진행하고 있으며, 선택된 마스크들은 빈칸으로 표시되어, 이후에 해당하는 위치의 가중치들이 가지치기가 되게 됩니다. 각 마스크 선택 과정에서 예상 오차가 가장 낮은 가중치들을 50%만큼 선택하고 있는 것을 그림을 통해 확인할 수 있습니다. 

Weight Reconstruction
마스크 선택이 끝나고 난 후에는 선택된 가중치들을 삭제하고, 이들로 인한 오차를 최소화하기 위해 남아있는 가중치들을 갱신하여야합니다. OBS[1] 알고리즘을 이용할 경우 각 가중치를 삭제하였을 때 출력의 오차를 최소화하는 가중치 변화량을 계산할 수 있지만, 이는 각 가중치를 삭제하였을 때의 가중치 변화량을 구하기 위해 값비싼 헤시안(Hessian)의 역수를 구하는 연산을 진행하여야 하므로, 매우 오랜 시간이 걸리게 됩니다. 저자들은 이러한 헤시안의 역수를 구하는 과정의 비용을 최소화하기 위하여 다음과 같은 방법들을 활용합니다.
  • 선택된 가중치들을 한 번에 삭제하는 것이 아니라, 열별로 삭제합니다. 가중치 복구 문제는 가중치 행렬의 각 행들에 대한 독립적인 문제로 분리될 수 있으며, 하나의 열에 있는 가중치들을 삭제할 경우 각 행들이 필요로 하는 헤시안의 역수가 같기 때문에 모든 행들에 대해 한 번만 헤시안의 역수를 연산하면 됩니다.
  • 하나의 열에 대해서 가중치들을 삭제하고, 가중치를 복구한 후에 다음 열에 있는 가중치들에 대해 삭제 및 가중치를 복구할 때, 기존 연구[2]에서 제안한 효율적인 헤시안의 역수를 구하는 알고리즘을 활용하여 비용을 줄입니다. 이렇게 하면 값비싼 헤시안의 연산은 각 마스크 선택 시행마다 한 번만 연산하게 됩니다.
그림 2. SparseGPT의 가중치 복구 과정
그림 2는 SparseGPT에서 4개 열에 대한 마스크 $M$이 주어졌을 때 가중치 복구를 진행하는 과정을 나타내고 있습니다. SparseGPT에서는 가장 왼쪽 열부터 선택된 가중치들을 삭제하고, OBS 알고리즘에 따라 가중치를 갱신합니다. 이때, 가중치 복구 문제는 각 행 별로 독립적인 문제이기 때문에, 가중치가 삭제되지 않은 행들에서는 가중치 갱신이 일어나지 않고, 가중치가 삭제된 행들에 대해서만 살아남은 가중치들을 갱신합니다. 또한, 한 열에 대해서 가중치 복구가 완료되면, 계산해두었던 헤시안의 역수를 활용하여 다음 열을 위한 헤시안의 역수를 효율적으로 계산합니다.

SparseGPT는 위 아이디어들을 바탕으로 기학습된 초거대 언어 모델들에게도 적용가능할 정도로 효율적이고, 정확한 가지치기를 수행합니다. 이어지는 내용에서는 SparseGPT에서 수행한 실험 결과들에 대해 설명하도록 하겠습니다.

Experiments

본 문서에서는 SparGPT의 실험 결과 중 가장 중요한 두 가지 결과에 대해 설명하도록 하겠습니다. 두 실험 결과는 모두 OPT[3] 모델을 가지치기한 후에 WikiText2[4] 데이터셋에서 펄플렉시티(perpelxity)를 측정한 결과입니다. 여기서 펄플렉시티란 생성형 언어 모델이 각 단어를 생성하는데 있어서 평균적으로 고민하는 단어의 개수로써 낮을 수록 좋은 값입니다.

그림 3. SparseGPT를 이용한 OPT모델 가지치기 결과
첫 번째 실험은 다양한 크기의 OPT모델에 대해 가지치기를 수행한 실험이며, 그림 3이 해당 실험 결과를 나타내고 있습니다. 그림의 왼쪽에서 오른쪽으로 갈 수록 큰 크기의 OPT모델에 대해 실험한 결과를 나타냅니다. 이때 SparseGPT는 큰 모델로 갈수록 작은 오차를 보이며, 가장 큰 175B 모델에 대해서는 펄플렉시티 증가가 없이 가지치기를 수행합니다. 그림에서 2:4 및 4:8 은 각각 4개의 가중치 묶음당 2개의 0이 있는 가지치기 패턴(pruning pattern)과 8개의 가중치 묶음당 4개의 0이 있는 가지치기 패턴을 의미합니다. 해당 패턴들은 A100 GPU에서 가속을 지원하는 실용적인 패턴들인데, 해당 가지치기 패턴들로 SparseGPT를 수행하여도 175B 모델에 대해 성능 하락이 거의 없이 가지치기를 수행할 수 있는 것을 확인할 수 있습니다. 이를 통해 SparGPT를 이용할 경우 단순히 메모리만 줄어들 뿐만 아니라, GPU를 이용한 추론 속도 향상까지도 가능하다는 것을 확인할 수 있습니다.

그림 4. GPTQ[2]와 SparseGPT를 함께 적용한 결과
두 번째 실험은 SparseGPT와 기존 양자화(quantization) 기법인 GPTQ[2]를 함께 적용하는 실험이며, 그 결과는 그림 4에 나타나 있습니다. 그림에서 갈색 선은 GPTQ를 이용하여 3비트로 압축된 OPT 모델을 의미하고, 분홍색 선은 GPTQ와 SparseGPT를 동시에 활용하여 같은 양의 메모리(memory)를 갖도록 압축한 결과를 의미합니다. 실험 결과, 단순히 GPTQ만을 사용했을 때 보다 GPTQ와 SparseGPT를 함께 사용하였을 때 더욱 우수한 성능을 보였고, 이를 통해 가지치기 알고리즘이 양자화 기법들과도 함께 사용되어 성능을 개선할 수 있음을 보여주었습니다. 

Conclusion
본 문서에서는 ICML 23에서 발표된 SparseGPT: Massive Language Models Can be Accurately Pruned in One-Shot 논문을 소개하였습니다. 해당 논문은 기학습된 초거대 언어모델을 가지치기하는 SparseGPT 알고리즘을 제안하였습니다. SparseGPT는 최대 1,750억개의 가중치를 가지는 초거대 언어 모델에 대해서도 1개의 A100 GPU만을 활용해서 3시간 안에, 정확도 하락 없이 가지치기하는 것을 성공하였습니다. 본 논문은 초거대 언어모델의 비용을 줄이고, 추론 속도를 향상시키는데 활용될 수 있습니다. 대표적인 예로, 많은 사람들이 사용하고 있는 ChatGPT에 적용될 경우 유지 비용을 줄일 수 있을 뿐만 아니라, 더 빠른 응답 속도를 가능케하여 서비스의 품질을 높임으로써 기업과 소비자 모두에게 도움이 될 수 있습니다. 이외에도 초거대 언어 모델이 사용되는 코드 생성, 텍스트 요약, 텍스트 기반 이미지 생성 등의 모든 응용의 효율을 높이기 위해 SparseGPT가 활용될 수 있기 때문에, 그 의미와 쓰임이 매우 큰 논문입니다. 본 논문에 대한 더 자세한 정보는 다음 링크에서 확인할 수 있습니다. (링크)

Reference
[1] Hassibi, Babak, David G. Stork, and Gregory J. Wolff. "Optimal brain surgeon and general network pruning.IEEE international conference on neural networks. IEEE, 1993.
[2] Frantar, Elias, et al. "Gptq: Accurate post-training quantization for generative pre-trained transformers.arXiv preprint arXiv:2210.17323 (2022).
[3] Zhang, Susan, et al. "Opt: Open pre-trained transformer language models.arXiv preprint arXiv:2205.01068 (2022).
[4] Merity, Stephen, et al. "Pointer sentinel mixture models.arXiv preprint arXiv:1609.07843 (2016).