본 문서에서는 2024년 ICLR에서 발표될 "Accurate Retraining-free Pruning for Pretrained Encoder-based Language Models"논문을 소개합니다. 논문에 대한 상세한 정보는 다음과 같습니다.
- Title: Accurate Retraining-free Pruning for Pretrained Encoder-based Language Models
- Authors: Seungcheol Park, Hojun Choi, and U Kang
- Conference: International Conference on Learning Representations (ICLR) 2024
Retraining-free Pruning of
Language Models
최근 트랜스포머(Transformer) 기반의 기학습 언어 모델(Pretrained Language Model, PLM)들은 사람에 가까운 언어를 구사하며, 단순히 대화를 할 수 있으을 뿐만 아니라 문제 풀이, 코드 작성 등에서도 뛰어난 성능을 보이고 있습니다. 하지만, 이러한 언어 모델의 발전은 언어 모델의 거대화를 바탕으로 이루어 졌습니다. 이로 인해 모델을 서비스하기 위한 유지 비용이 매우 값비싸지고, 환경 오염을 일으키며, 개인의 모바일 기기에 탑재하여 개인화된 서비스를 제공하기가 어려워졌다는 부작용들이 발생합니다.
이러한 부작용들을 없애는 방법 중 하나는 모델 내에서 불필요한 부분을 식별하여 잘라내는 가지치기 기법입니다. 하지만, 일반적인 가지치기 기법들은 모델 내의 불필요한 부분을 식별하고 정확도 손실을 최소화하며 잘라내기 위해 긴 재학습 과정을 필요로 하며, 이로 인해 거대한 언어 모델들에서는 사용이 불가능합니다. 그리하여 거대한 언어 모델들에도 사용 가능하도록 설계된 재학습이 없는 가지치기 기법들이 제안되었지만, 이들은 간소화 된 가지치기 과정으로 인해 성능을 보존하지 못한다는 단점이 있습니다. 그러므로, 기존 가지치기 기법들은 거대한 언어 모델들에 대해 성능 저하 없이 효율적으로 압축하지 못하며, 본 논문에서 이러한 문제를 해결하고자 합니다.
본 논문에서 해결하고자 하는 재학습 없는 기학습 언어 모델 가지치기 문제의 구체적인 정의는 다음과 같습니다. 여기서 재학습이란 기학습 모델의 사전 학습(pretraining) 과정이나, 특정 테스크에 알맞게 파인 튜닝(fine-tuning)하는 정도의 고비용이 드는 학습 과정을 의미하며, 재학습이 없다는 말은 이보다 훨씬 적은 비용이 드는 과정을 의미합니다.
- 주어진 정보
- 학습되어 있는 고성능의 언어 모델
- 적은 숫자의 샘플 데이터 셋
- 원하는 가지치기 된 모델의 크기
- 연산 수(FLOPs)를 기준으로 설정
- 목표
- 주어진 정보들을 활용하여 원하는 크기에 맞고 정확한 언어 모델을 생성하는 것
- 제한 조건
- 경사 하강법등을 이용한 고비용의 재학습 과정이 없어야 함
재학습이 없으면서도 기학습 언어 모델의 성능을 보존할 수 있는 가지치기 알고리즘을 설계하기 위해서는 아래와 같은 문제점들을 해결해야 합니다.
- (중요도 파악 기준) 모델 내에서 불필요한 요소들을 선택하기 위해서는 어떤 기준으로 중요도를 파악해야 할까요?
- (불필요 요소 파악) 파악된 중요도를 바탕으로 어떻게 하면 모델 내에서 가지치기할 불필요한 요소들을 선택할 수 있을까요?
- (정확도 손실 최소화) 불필요한 가지치기 요소를 삭제하면서 생기는 정확도 손실을 어떻게 하면 재학습 없이 최소화할 수 있을까요?
Proposed Method
본 논문에서는 앞서 설명한 문제점들을 해결하기 위한 기법인 Kprune를 제안합니다. 저자들은 아래와 같은 핵심 아이디어들을 활용하여 Kprune를 설계하였습니다.
- (지식량 측정) 모델의 각 요소별로 담겨 있는 지식의 양을 측정하고, 이를 바탕으로 모델의 중요도를 파악합니다.
- (지식을 보존하는 마스크 탐색) 모델의 각 요소 종류 별 특징, 지식의 종류 별 특징을 고려하여 가지치기할 요소를 탐색합니다. 또한, 모델 전체에서 가지치기 할 요소를 탐색함으로써 각 서브레이어(sublayer) 별로 중요도를 반영하여 가지치기할 요소의 양을 조정합니다.
- (지식을 보존하는 가중치 조정) 가지치기 대상으로 선정된 요소들을 삭제한 후에 0.1초 내에 수행되는 매우 효율적인 가중치 조정 과정을 수행합니다.
그림 1. Kprune의 동작 과정 예시
지식량 측정 (Knowledge Measurement)
기학습된 모델의 성능을 보존하기 위해서는 방대한 양의 데이터를 통해 학습된 기학습된 모델의 지식을 보존하는 것이 중요합니다. 특히 재학습이 없는 가지치기의 경우 지식이 한 번 훼손되면 이를 다시 복구할 수 없기 때문에 이러한 지식의 양을 측정하는 과정은 더욱 필수적입니다. Kprune에서는 기학습된 모델의 지식을 예측 지식(predictive knowledge)와 표현 지식(representational knowldge)으로 분류하여 각 요소별로 두 지식의 양을 측정합니다. 지식 측정은 모델에서 각 요소를 삭제하여 생성되는 모델와 삭제 전의 기학습 모델의 지식의 차이를 재는 방식으로 이루어집니다. 그림 1 (a)를 보면 가지치기 대상인 모델의 요소들 (어텐션 헤드(attention head) 및 뉴런(neuron))에 대해서 두 종류의 지식의 양을 측정한 것을 알 수 있으며, 붉은 색이 진할 수록 예측 지식의 양이 많은 것이고, 푸른색이 진할 수록 표현 지식의 양이 많을 것을 알 수 있습니다. 이때, 첫 번째 서브레이어의 경우에는 이미 가지치기가 종료되어 지식의 양이 측정되지 않았습니다.
지식을 보존하는 마스크 탐색
(Knowledge-preserving Mask Search, KPMS)
두 번째 단계는 측정된 지식의 양을 바탕으로 가지치기 할 대상을 선정하는 일입니다. 현재 두 가지 종류의 지식의 양이 각각 두 종류의 모델 요소들에 대해 측정되었기 때문에, 요소의 종류별 특징과 지식의 종류 별 특징을 고려하여 가중치 합을 통해 모델의 중요도를 측정하게 됩니다. 그림 1 (b)에서는 이러한 종합적인 중요도 점수를 보라색으로 표현하였으며, 보라색이 짙을 수록 중요한 요소가 됩니다. 다음으로 이러한 중요도 점수를 이용하여 요소들을 정렬한 후에 모델의 압축 조건에 따라 가지치기할 요소들을 선정하게 됩니다. 그림 (b)의 (2)에서는 현재 2번째 서브레이어의 뉴런들 2개와 3번째 서브레이어의 어텐션 헤드 하나가 대상으로 선정된 것을 알 수 있습니다. 이때, 중요한 점은 우리는 가지치기 대상인 두 번째 서브레이어 뿐만 아니라 다른 서브레이어의 중요도까지도 측정하기 때문에 모델 전체에서의 중요도를 바탕으로 대상 서브레이어에서 가지치기할 양을 정할 수 있습니다. 만약 두 번째 서브레이어가 더욱 중요했다면 2개의 뉴런들이 선택되지 않고, 하나만 선택되거나, 혹은 선택되지 않았을 것입니다. 이러한 과정을 통해 Kprune은 각 서브레이어별로 중요도에 따라 가지치기 양을 올바르게 조정할 수 있습니다.
지식을 보존하는 가중치 조정
(Knowledge-preserving Weight-tuning KPWT)
다음 단계는 선택된 요소들을 삭제하고 가중치를 조정하는 단계입니다. 앞선 (b) 단계에서 선택된 요소들 중 대상 서브레이어에 있는 2개의 뉴런들을 삭제해줍니다. 또한 이로 인한 지식 손실을 복구하기 위하여 기학습 모델의 출력을 최대한 복원할 수 있도록 남아있는 뉴런의 가중치를 조정해줍니다. 본 논문에서는 해당 과정을 Pytorch에서 제공하는 선형 솔버(linear solver)를 이용하여 풀 수 있도록 설계하였으며, 그 결과 약 0.1초 안에 가중치 조정이 완료됩니다. 가중치 조정을 통해 모델의 지식을 복구한 후에는 다음 서브 레이어로 넘어가 해당 과정을 반복해나가게 되며, 이때 모델의 연산 수(FLOPs) 예산을 대상 레이어에서 가지치기 되지 않고 남은 요소의 연산 수 만큼 삭감한 후에 진행합니다. (a)부터 (c)까지의 과정을 가장 아래의 서브레이어에서부터 가장 끝의 서브레이어까지 반복하여 진행하게 되면 Kprune을 이용한 가지치기 과정이 종료되게 됩니다.
Experiments
본 문서에서는 Kprune의 실험 결과 중 가장 중요한 두 가지 결과에 대해 설명하도록 하겠습니다. 두 실험 결과는 모두 BERT 및 DistilBERT 모델을 가지치기한 후에 GLUE 밴치마크(benchmark)를 이용하여 평가한 결과입니다.
첫 번째 실험은 다른 재학습을 하지 않은 가지치기 알고리즘들과 다양한 압축률에서 가지치기된 모델의 정확도를 비교하는 실험입니다. 현재 각 그래프에서 오른쪽으로 갈 수록 압축률이 높아지는 것이고, 위로 갈 수록 정확도가 높은 것인데, 모든 압축률 구간에서 Kprune이 다른 기법들보다 높은 정확도를 보이는 것을 확인할 수 있습니다. 특히 이러한 간극은 압축률이 높아질 수록 급격히 커지는 경향이 있으며 SQuAD 1.1 데이터셋에서 최대 58.02%p까지도 더 높은 경향을 보이는 것을 확인할 수 있습니다.
그림 4. 가지치기 효율성 비교 실험 결과
두 번째 실험은 Kprune과 다른 가지치기 알고리즘들의 가지치기 효율성을 비교하기 위해 가지치기 비용 및 가지치기된 모델의 정확도를 함께 비교해보았습니다. 가지치기 비용은 가지치기를 하기 위해 소요된 시간으로 측정하였으며, Kprune의 효율성을 보여주기 위해 재학습을 필요로 하는 가지치기 기법들인 EBERT와 DynaBERT-w, DynaBERT-d 도 실험에 포함하였습니다. 해당 실험에서는 왼쪽 위의 BEST에 가까울 수록 가지치기 알고리즘의 효율성이 높은 것인데, Kprune이 해당 지점에 가장 가까운 결과를 보인 것을 알 수 있습니다. 특히, 다른 알고리즘을 이용하여 Kprune과 비슷한 정확도를 내기 위해서는 최대 422배까지도 더 많은 시간을 투자해야하는 것을 알 수 있으며, 이를 통해 Kprune이 매우 효율적인 가지치기 알고리즘임을 알 수 있습니다.
그림 4. 가지치기 효율성 비교 실험 결과
두 번째 실험은 Kprune과 다른 가지치기 알고리즘들의 가지치기 효율성을 비교하기 위해 가지치기 비용 및 가지치기된 모델의 정확도를 함께 비교해보았습니다. 가지치기 비용은 가지치기를 하기 위해 소요된 시간으로 측정하였으며, Kprune의 효율성을 보여주기 위해 재학습을 필요로 하는 가지치기 기법들인 EBERT와 DynaBERT-w, DynaBERT-d 도 실험에 포함하였습니다. 해당 실험에서는 왼쪽 위의 BEST에 가까울 수록 가지치기 알고리즘의 효율성이 높은 것인데, Kprune이 해당 지점에 가장 가까운 결과를 보인 것을 알 수 있습니다. 특히, 다른 알고리즘을 이용하여 Kprune과 비슷한 정확도를 내기 위해서는 최대 422배까지도 더 많은 시간을 투자해야하는 것을 알 수 있으며, 이를 통해 Kprune이 매우 효율적인 가지치기 알고리즘임을 알 수 있습니다.