Published on

LLaMA 2 Open Foundation and Fine-Tuned Chat Models 논문 리뷰

Authors
  • avatar
    Name
    JeongYun Lee

해당 글은 'LLaMA 2: Open Foundation and Fine-Tuned Chat Models' 논문 리뷰이며, Udemy의 '모두를 위한 대규모 언어 모델 LLM Part 1Section3: LLaMA2 모델 리뷰' 강의를 적극 참고하여 작성하였습니다.

2023년 7월에 💫같이 등장한 LLaMA2는 Transformer 기반의 LLM으로, LLaMA1 출시 후 5개월 만에 상당한 성능 향상을 이뤄냈습니다. LLaMA가 큰 주목을 받은 이유는 바로 모델의 가벼움 대비 뛰어난 성능과 누구나 상업적 사용이 가능하도록 한 오픈소스 정책 때문입니다. 학습과 구동할 때의 compute budget을 고려하여 GPT3 175B(1750억 파라미터)에 비해 1/10 이상 사이즈를 축소시켰으며, 성능 평가를 했을 때 여러 벤치마크에서 더 나은 성능을 보인다고 합니다. 또한 LLaMA 6.7B 모델은 단일서버(V100)에서도 실행 가능하여 보다 접근성 있는 모델이라고도 밝혔습니다. 무제한으로 늘어가는 모델의 크기와 컴퓨팅 리소스 관련해서 늘 이슈가 있었던 LLM 시장에 반가운 모델이 등장한 것이죠.

LLaMA2 논문은 총 8파트로 구성됩니다.

  1. Abstract
  2. Introduction
  3. Pre-Training
  4. Fine-Tuning
  5. Safety
  6. Discussion
  7. Related Work
  8. Conclusion

아래 그림은 chat 모델을 학습하는 Pre-Training과 Fine-Tuning 과정에 대해 전체적으로 정리한 내용입니다. 본 리뷰에서는 이 학습 순서에 따라서 하나씩 살펴보도록 하겠습니다.

LLaMA2의 pre-training과 fine-tuning 과정

1. Pre-Training

LLaMA2의 Pre-Training은 LLaMA1과 거의 유사하게 진행되어 해당 논문에서는 이전 버전과 차이점에 초점을 두어 설명합니다.

  1. 더 많은 학습 데이터 사용 학습 데이터는 인터넷에서 공개적으로 수집할 수 있는 데이터를 사용했고, Meta의 데이터(Instagram, Facebook 등의 개인 정보)는 사용하지 않았다고 합니다. LLaMA1에 비해 40% 증가한 2조개의 학습 토큰을 사용했는데, 이때 토큰의 의미는 말뭉치(corpus)의 길이가 2조개라는 의미가 아닌, 전체 epoch에서 사용한 토큰을 다 합친 값입니다. 예를 들어 학습 epoch가 2일 때, 학습 corpus는 1조개가 되는거죠. 한 가지 특이한 점은, 학습 데이터에서 모든 유해 컨텐츠를 필터링 하진 않았다는 것입니다. 일반적으로 모델의 편향성과 유해성을 이유로 학습 데이터셋을 robust하게 정제하는 경향이 있는데 LLaMA의 경우 모델이 기본적으로 유해한 내용도 전부 학습하게 한 뒤, Fine-Tuning 과정에서 해당 내용은 유해하다는 것을 다시 학습시켜주는 방식을 선택했습니다. 또한 신뢰할 수 있는 출처의 데이터는 업샘플링하여 더 많이 학습되도록 만들어서 hallucination을 최대한 줄이도록 노력했다고 합니다. 학습 데이터의 언어는 다양하지만, 영어가 89.7%를 차지하고 한국어는 0.06%가량 밖에 없어서 실제로 LLaMA를 사용해보면 한국어 성능이 매우 떨어지는 것을 확인할 수 있습니다.
  2. context 길이를 2배로 LLaMA2는 기존의 context 길이인 2048의 2배인 4096입니다. context 길이는 한 번의 input으로 넣어 주는 문장의 길이를 의미합니다. context의 길이가 길수록 앞뒤 문맥 파악과 포함할 수 있는 정보의 양은 많아져 성능이 향상되지만, 학습은 어려워집니다.
  3. Grouped Query Attention (GQA) GQA는 구글에서 개발한 기술로, MHA(Multi-Head Attention)와 MQA(Multi-Query Attention)의 장점을 결합하여 추론 속도를 빠르게 하면서 성능은 유지할 수 있는 방식입니다.
GQA와 MHA, MQA 비교

GQA는 Multi-head와 Mul-Query의 중간적 방법으로, MHA보다 디코딩 속도를 향상시키며 하며 동시에 MQA의 학습의 불안정성을 보완할 수 있는 방법입니다. Query의 그룹을 만들고 각 그룹마다 Key와 Value를 두는 방식입니다. GQA는 이미 학습된 모델에도 적용할 수 있으며(pre-trained된 모델에 추가로 GQA로 학습시키는 방법) LLaMA2에서는 70B모델에서만 적용하여 성능이 뛰어난 것을 확인할 수 있습니다.

MHA와 MQA란?

Multi-Head Attention은 Attention score를 계산할 때 Query, Key, Value를 각각 head의 수 만큼 병렬로 나누어 계산을 한 뒤, 마지막에 concatenate하여 하나로 합치는 방식입니다. 여러 부분(head)에서 결과를 나누어 도출하여 서로 상호보완하기 때문에 기존에 비해 성능이 향상될 수 있다고 합니다. Multi-Query-Attention은 Query에 대해서는 원래 헤드 수를 유지하지만 Key와 Value에 대한 헤드는 하나만 갖는 것입니다. 이 방식은 메모리를 효율적으로 사용하여 디코더에서 상당한 속도 향상을 이룰 수 있으나 성능저하의 위험이 있습니다.

GQA와 관련된 자세한 내용은 다음 논문을 참고하세요.

2. Fine-Tuning: Supervised Fine-Tuning (SFT)

SFT는 입력 prompt값과 출력 response 값 쌍을 활용하여 학습하는 방식입니다. LLaMA2에서는 SFT 과정에서는 데이터의 양보다 질이 중요하다는 선행연구(LIMA:) 에 따라서 자체적으로 데이터를 생산하여 학습을 진행합니다. 약 27540개의 고품질 데이터를 직접 생산하여 학습했을 때, 기존에 구할 수 있는 수백만개의 저품질 데이터로 학습하였을 때 보다 성능이 향상되었다고 합니다.

구축한 prompt와 response데이터 쌍은 모델의 시퀀스 길이를 적절히 맞추어 학습 효과를 높이기 위해 (너무 짧은 문장이 들어가면 문맥파악이 어려워 학습 효과가 낮기 때문) prompt<SEG>response와 같은 형식으로 concatenate 해서 학습의 input으로 넣어줍니다. 이후 사용자 프롬프트 토큰에 대한 loss를 0으로 만들어 결과적으로 정확한 답변을 생성하도록 하는 것이 목표이므로 실제 정답값에 대해서만 backpropagation을 진행합니다.

'Auto-regressive'란? ✅ (p9)
자기 자신을 입력으로 하여 스스로 예측하는 모형. 현재 시점까지 생성한 ouput을 사용하여 다음 시점의 output에 대한 예측을 수행하는 모델

SFT와 관련된 추가적인 학습은 다음 논문을 참고하세요

3. Fine-Tuning: Main(Aligned) Model

LLaMA2는 pretraining부터 SFT를 지나는 chat을 위한 전체적인 main(aligned) model과 fine-tuning을 위한 reward 두개의 모델이 존재합니다. (두 모델 학습의 순서는 설명 편의상 부여한 것이며 논문 상 정확한 언급은 없습니다) SFT를 진행한 뒤 생성된 모델을 활용하여 각 prompt 당 K개의 response를 출력합니다. 이후 . 이 출력값 각각에 대해 reward model을 통해서 reward score를 생성하고 이 중 가장 높은 값을 갖는 response를 선택하는 과정이 Reject Sampling입니다.

다음 4단계에서는 위 과정에서 필요한 Reward Model을 만드는 방법에 대한 설명입니다.

4. Fine-Tuning: Reward Model

초기 reward 모델은 오픈소스로 공개된 Human Preference Dataset과 Meta에서 구축한 데이터를 이용하여 학습을 진행합니다.

  1. helpfulness: meta helpfulness50%, meta safety25%, open source 25%
  2. safety: 90%는 meta + Anthropic Harmless, 10%는 meta safety

temperature 파라미터를 조절하여 두개의 모델을 만들고 각 모델에 동일한 prompt를 넣어서 response를 출력합니다. 이후 사람은 두개의 출력값에 대하여 safety와 helpfulness 두 가지 관점에 따라서 response를 평가하게 됩니다. 평가 선지는 significantly better, better, slightly better, negligibly better 네 단계가 있습니다. 예를들어, '오늘은 비가오는데, 저녁을 추천해줘'라는 prompt에 대하여 modelA는 '비가 오는 날에는 따뜻한 국물 요리가 좋겠어요. 라면 어떠신가요?' 라는 답변을 하고 modelB는 '김밥을 추천합니다. 다양한 재료가 들어가 건강에 좋은 음식이에요.' 라는 답변을 했다면 safety 관점에서는 두 답변 모두 안전성 측면에 차이가 없으므로 negligibly better을 선택할 수 있고, helpfulness 관점에서는 modelA의 답변이 더 적절하지만 mobelB의 답변도 틀리진 않기 때문에 better 정도를 선택할 수 있습니다. LLaMA2는 안전성을 강조하기 때문에 safety 관점에서는 한 가지 평가를 더 진행합니다. 동일하게 출력된 두 response를 3가지 평가 선지(one is safe but the other is unsafe, both safe, both unsafe)에 대하여 평가하고 both unsafe라는 답변을 받은 reponse는 학습 데이터에서 제거합니다. human annotators의 helpfulness와 safety 각각의 이진 분류 라벨에 따라서 pretrained과 SFT를 마친 모델을 활용하여 이진분류 학습을 진행합니다. 이 학습은 pretraining의 학습 방식과 동일하지만, 다음 토큰 예측을 위한 분류 헤드가 scalar reward score를 출력하기 위한 regression 헤드로 대체된다는 차이가 있습니다. 출력된 chosen scalar score와 Rejected scalar score를 활용하여 기존에 4단계를 나눠 답변을 받은 것에 대한 margin을 부여하여 ranking loss를 계산합니다. margin은 두 문장 사이의 거리, 즉 얼마나 비슷한지에 대한 값이라서 두 문장의 차이가 크지 않다면 margin은 작습니다. 즉, model이 confidence를 쫌 더 확실하게 가져가게 하도록 하기 위해서 도입한 것 입니다. 아래 표에서 margin을 고려했을 때 성능이 더 좋다는 것을 확인할 수 있습니다.

margin에 따른 model accuracy 차이

Meta helpfulness and safety data는 reward modeling을 위해서 구축하여 domain적 한계가 있을 수 있으나, 100만개 이상의 이진비교를 실시했고 전체 데이터 수, 대화당 평균 턴 수, 예시당 평균 토큰 수, 프롬프트와 응답당 평균 토큰 수 등에서 타 데이터 셋에 비해 우수하다고 합니다.

Meta에서 구축한 helpfulness and safety data 통계

5. Reinforce Learning Human Feedback (RLHF)

K개의 reward score를 생성하기 위해 4에서 생성한 Reward Model에 K개의 response를 넣어줍니다.

6. RLHF: Reject Sampling

prompt 당 출력된 K개의 reward score은 safety reward score와 helpfulness reward score 각각 생성됩니다. 이 중 가장 높은 점수의 답변만 sampling 해주는 단계가 Reject Sampling입니다.

이 과정은 70B 모델의 경우에만 실행하고 이 보다 작은 모델의 경우 70B모델의 생성결과를 SFT 방식으로 distillation(가장 큰 모델을 통해서 도출된 결과값을 작은 모델에서 사용한다는 의미로 해석) 해준다고 합니다. (distillation의 성능과 관련된 연구는 향후 테스크로 둠)

7. RLHF: Iterative Fine-Tuning

기존의 RLHF는 선형적으로 실행되는 반면, LLaMA2에서는 반복적으로 학습하여 SFT에서 Reward Model의 결과를 재학습할 수 있도록 합니다. 6번에서 sampling해준 데이터를 다시 SFT 학습 데이터로 넣어줘서 2~6까지의 학습 과정을 반복하는 것입니다. LLaMA2에서는 이 반복을 5번 해줬다고 합니다.

학습을 반복하여 진행했을 때 성능이 좋아지는 지에 대한 판단은 Model-based와 Human Evaluation 두 가지로 진행하여 판단합니다. 우선 Model-based Evaluation은 학습을 반복을 할 수록 reward model의 값이 향상되는 지에 대해 확인하기 위해 reward model의 score과 3명의 annotator가 책정한 점수 간의 상관관계를 파악합니다. safety와 helpfulness에 대하여 각각 1586개, 584개의 별도 테스트 데이터를 생성해서 진행했을 때 아래 그래프에서와 같이 양의 상관관계를 확인할 수 있습니다. fig 29 이렇게 검증한 reward model을 사용하여 반복한 RLHF V1~V5까지의 결과는 ChatGPT와의 상대 비교를 통해 확인하는데, helpfulness와 safety는 모두 itertation을 할 수록 성능이 향상되었다고 합니다. 두 번째는 Human Evaluation입니다. 4000개의 single, multi turn dialog를 준비하고 model로부터 response를 생성하도록 한 뒤 3명의 annotator가 책정한 점수를 확인해보면 ChatGPT와 비슷한 수준의 결과를 얻었다고 합니다. (이러한 평가 방법에 대한 한계점이 있음을 언급하기도 합니다)

기존의 RLHF 학습 방법과 관련된 내용은 아래 논문을 참고하세요.

8. RLHF: Proximal Policy Optimization

Reject Sampling을 5번 해준 뒤, 마지막 단계에서는 anootated 된 사람의 reward score를 기반으로 최적화하는 PPO를 진행합니다. 이때 사용하는 reward는 출력된 safety reward와 helpfulness reward 중 하나의 reward만 선택해주는 과정을 거칩니다. 두 reward의 값을 비교했을 때, safety reward의 값이 0.15보다 크면 helpfulness reward를 선택하여 정확도를 더욱 높이고, 그렇지 않은 경우 safety reward를 선택하여 안전성을 높이는 데 기여하도록 합니다.

Proximal Policy Optimization에 대해서 정확히 이해하기 위해서는 강화학습 전반에 대해 이해가 필요합니다. PPO와 관련된 세부적인 내용은 아래 논문을 참고하세요.

9. Ghost Attention (GATT)

Ghost Attention은 LLaMA2의 contribution에 해당되는 기술로, Multi-turn 채팅 상황에서 Instruction이 지속되도록 하는 것을 의미합니다. 이때 Instruction은 논문 예를 따르면, '질문에 대해 이모티콘으로만 답해줘' 라는 지시가 있으면 prompt와 response가 계속 지속되어도 끝까지 문자로 답하지 않고 이모티콘으로 답하도록 유지하는 것을 의미합니다. Ghost Attention의 방법은 Context distillation에서 착안하여 개발하였고 데이터 생성 과정에서만 instruction을 user의 입력마다 삽입해서 supervised fine-tuning 과정에서 진행하게 됩니다.

참고자료

  • 리뷰 Meta AI의 Small Gaint Model: LLaMA(Large Language Model Meta AI), daewoo kim [바로가기]
  • 오픈소스 LLM의 패러다임 전환: Meta AI의 LLaMA2 — (1) overview, daewoo kim [바로가기]
  • 리뷰 Meta AI의 Small Gaint Model: LLaMA(Large Language Model Meta AI), daewoo kim [바로가기]
  • 리뷰 Meta AI의 논문 LIMA(Less Is More for Alignment): 결국 LLM의 Pre-training이 가장 중요하다?, daweoo kim [바로가기]
  • 논문 리뷰 LLaMA2, DAJE [바록가기]
  • Paper Review LLaMA 2: Open Foundation and Fine-Tuned Chat Models, Jaehee Kim [바로가기]
  • 인공지능,머신러닝,딥러닝 (심화) LLaMA2, 컴달인 [바로가기]
  • 인공지능,머신러닝,딥러닝 (심화) LLaMA2 - Ghost Attention, 컴달인 [바로가기]
  • 라마2에 적용된 추론 속도 향상 기술인 GQA(Grouped Query Attention)에 대해, singleheart [바로가기]
  • Multi-Query Attention Explained, Florian [바로가기]
  • Auto Regressive Models, ratsgo's blog [바로가기]