Distilling Step-by-Step! Outperforming Larger Language Models with Less Training Data and Smaller Model Sizes
Cheng-Yu Hsieh, Chun-Liang Li, Chih-Kuan Yeh, Hootan Nakhost, Yasuhisa Fujii, Alexander Ratner, Ranjay Krishna, Chen-Yu Lee, Tomas Pfister
arXiv:2305.02301
ACL 2023
요약
LLM에서 label 뿐 아니라 rationale까지 추출하여 소형 모델을 multi-task로 학습시키면, 기존 finetuning이나 distillation보다 훨씬 적은 데이터와 작은 모델로도 LLM을 능가할 수 있다는 것을 보인 논문이다.
문제 정의

540B PaLM 같은 대규모 LLM을 서빙하려면 수백 GB의 GPU 메모리와 특수 인프라가 필요하다. 현실적으로 대부분의 서비스 팀은 이런 자원을 감당할 수 없으므로, 작은 task-specific 모델을 배포하는 것이 일반적이다.
작은 모델을 만드는 기존 방법은 두 가지다.
Standard finetuning은 사람이 만든 레이블로 사전학습 모델을 미세조정하고, Standard distillation은 LLM이 생성한 pseudo label로 학습한다. 문제는 두 방법 모두 LLM 수준의 성능에 도달하려면 대량의 학습 데이터가 필요하다는 점이다.
이 논문의 핵심 통찰은 관점의 전환에 있다. LLM을 단순히 "노이즈 있는 레이블 생성기"로 보는 대신, "추론할 수 있는 에이전트"로 바라본다. LLM은 Chain-of-Thought(CoT) 프롬프팅으로 답뿐만 아니라 그 답에 도달한 이유(rationale)까지 생성할 수 있고, 이 rationale에는 소형 모델이 많은 데이터를 통해서야 겨우 학습할 수 있는 task knowledge가 담겨 있다.
핵심 아이디어 - Rationale을 Student의 추가적인 감독 신호로 쓰자

Step 1: LLM에서 Rationale 추출
Few-shot CoT 프롬프팅으로 LLM(540B PaLM)에 입력을 넣으면, 레이블과 함께 자연어 rationale이 생성된다.
예를 들어, "골프 장비를 들고 있는 사람이 어디로 갈까?"라는 질문에 대해 LLM은 "골프에 사용되는 곳이어야 하고, 선택지 중 클럽만 골프에 해당한다"라는 rationale과 함께 "club"이라는 답을 출력한다.
Step 2: Multi-task Learning으로 소형 모델 학습
추출한 rationale을 활용하는 방법이 핵심이다. Rationale을 추가 입력으로 넣는 방식은 추론 시에도 LLM이 필요하므로 배포 문제가 해결되지 않는다. 또한 rationale과 label을 단일 시퀀스로 이어 붙여 하나의 타겟으로 학습하는 single-task 방식은 오히려 성능을 해칠 수 있다.
이 논문은 multi-task learning을 택한다. 입력에 task prefix([label] 또는 [rationale])를 붙여, 하나의 모델이 레이블 예측과 rationale 생성을 동시에 학습한다.
$$L=L_{label}+λ·L_{rationale}$$
핵심은 추론 시에는 [label] prefix만 사용하므로 LLM 없이 소형 모델만으로 예측이 가능하다는 점이다. Rationale 생성 태스크는 학습 시에만 보조 감독 신호로 작용하여, 모델이 입력과 출력 사이의 관계를 더 깊이 이해하도록 유도한다.
실험 결과
4개 NLP 벤치마크(e-SNLI, ANLI, CQA, SVAMP), Teacher LLM은 540B PaLM, Student는 T5(220M/770M/11B)로 실험했다.
데이터 효율성: 220M T5 기준, Distilling Step-by-Step은 e-SNLI에서 전체 데이터의 12.5%만으로 standard finetuning(100% 데이터)을 능가했다. 평균적으로 50% 이상의 데이터 절감을 달성했다.
모델 크기 절감: 540B PaLM의 Few-shot CoT를 e-SNLI에서 220M T5로 능가했고, ANLI에서는 770M T5(700배 이상 작은)로 능가했다. 레이블 없이 unlabeled 데이터만 사용해도 3/4 데이터셋에서 LLM을 넘어섰다.
동시 절감: ANLI에서 770M T5 + 80% 데이터만으로 540B PaLM을 능가한 반면, standard finetuning은 100% 데이터를 써도 따라잡지 못했다.
Ablation: Multi-task 학습이 single-task(rationale+label 이어 붙이기) 대비 일관되게 우수했다. Single-task는 ANLI, CQA에서 standard finetuning보다 오히려 성능이 떨어지기도 했다. 또한 20B GPT-NeoX에서 추출한 rationale로도 성능 향상이 있었지만, 540B PaLM 대비 향상 폭이 작아 rationale 품질이 중요함을 보여준다.
강점 및 한계
강점
LLM의 추론 능력을 "학습 신호"로 재활용한다는 발상
Hinton의 original KD가 soft target으로 오답 간 확률 구조를 전달했다면, 이 논문은 자연어 rationale로 명시적 추론 과정을 전달한다.
Multi-task 프레임워크 덕분에 추론 시 LLM이 불필요하다는 점이 실용적이다.
한계
Few-shot CoT 프롬프트 구성에 사람의 개입(약 10개 예시)이 필요하며, LLM이 복잡한 추론에서 한계를 보일 경우 rationale 품질이 저하될 수 있다. Teacher의 bias가 student에 그대로 전이되는 문제도 남아 있다.
실제로 지금 하고 있는 연구 과제에 이 논문의 방법론을 직접 적용해보았는데, teacher에서 나온 rationale의 품질과 few-shot 프롬프트의 품질에 크게 의존하는 경향을 보였다.
Teacher가 내뱉는 rationale의 퀄리티가 달라지면 student에 distill 하였을 때의 성능도 큰 변동이 있었다.
그리고 논문에서는 540B PaLM, 그리고 20B GPT-NeoX을 사용했지만, 실제로 7.8B 모델을 teacher model로 사용하여 실험해보았을 때 student model을 그냥 fine-tuning 하는 것보다 더 정확도가 떨어지는 모습을 보였다.
그럼에도 multi-task loss를 통해 rationale을 명시적으로 학습 신호에 포함시키는 것만으로, 추론 시점에서 모델 파라미터에 추론 능력이 내재화된다는 점을 보여준 논문이라고 생각한다.