본문 바로가기

AI

[LLM] Fine-tuning시 early stopping 적용하기

1. Early Stopping 사용 배경

LLM 모델을 fine-tuning 하는 과정에서 loss가 줄어들다가 다시 증가해서 2 epoch가 모두 돌아갔을 때 오히려 성능이 떨어지고 있었다. 

이럴 때, 기존 딥러닝 모델 학습 시 early stopping을 콜백 함수를 넣어 지정한  early_stopping_patience 가 지나면 학습을 멈추게 했었는데 사전학습된 모델을 불러와 fine-tuning할 때도 해당 기능을 쓸 수 있는지 궁금했다. 

LLM 모델을 fine-tuning할 때 쉽게 사용하는 클래스가 허깅페이스의 SFTTrainer와 Trainer 클래스인데 
Trainer 클래스에서는 EarlyStoppingCallback 함수를 지원하지만 SFTTrainer에서는 지원하지 않는다. 

 

2. Trainer 클래스를 이용한 EarlyStopping 적용

해당 callback을 추가하기 위해 eval_dataset과 evaluation_strategy를 추가해줘야 한다. 

 
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback
 
# Early stopping callback 설정
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)  # 예: 3번의 evaluation 마다 validation loss가 감소하지 않으면 학습 종료

trainer = Trainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=validation_data,
    args=TrainingArguments(
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=1,  #  그라디언트 누적 단계 수 지정. 
        max_steps=500,                           # 배치가 총 몇 번 학습되는지
        num_train_epochs=1,                   # epochs 대신 max_steps을 기준으로 할 수 있습니다.
        learning_rate=1e-4,
        fp16=True,                                   # 모델 매개변수와 그라디언트 16비트 부동 소수점으로 저장하여 학습 가속화
        logging_steps=10,                       # 로그 출력 시 간격 지정
        output_dir=new_model,             # 학습된 모델과 로그 저장할 디렉토리
        optim="paged_adamw_8bit",     # 사용할 옵티마이저 
        evaluation_strategy="steps",
        eval_steps = 50,                        # 해당 스텝 이후 평가
        load_best_model_at_end=True
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),   # 데이터 배치 처리기
    callbacks=[early_stopping]       # Early stopping callback 추가
)
 
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

trainer.train()

# fine-tuning된 모델 저장
model.save_pretrained(new_model)

 

3. Trainer vs. SFTTrainer 비교

Trainer 클래스 같은 경우 처음부터 LLM 모델을 학습할 때 일반적으로 쓰이고 다양한 configuration option이 있어 custom하기가 용이하다. 

SFT를 한다고 하면 많이 보이는 SFTTrainer는 Fine-tuning에 맞춰진 클래스기 때문에 적은 시간과 적은 데이터셋으로 효과적으로 학습이 가능하지만 옵션이 Trainer에 비해서 적다. 

처음부터 대량의 데이터로 LLM 모델을 학습하는 경우에는 여러 옵션을 설정하여 여러 실험으로 성능을 높일 수 있는 Trainer 클래스를 추천하며, 간단하게 미세조정학습만 하는 경우에는 SFTTrainer를 사용하는 것을 추천한다. 

  Trainer SFTTrainer
목적 처음부터 학습  사전학습 모델의 미세조정학습 
커스터마이징 configuration 옵션이 다양하여 커스텀 용이 사용 용이하지만 옵션이 적음
학습 Workflow gradient accumulation, early stopping,
checkpointing, distrubuted training 지원 
간소화된 workflow
필요 데이터 Larger datasets Smaller datsets
메모리 사용 Higher Lower with PEFT and packing optimization
학습 속도 Slower Faster with smaller datasets and shorter times

 

 

참고

- https://medium.com/@sujathamudadla1213/difference-between-trainer-class-and-sfttrainer-supervised-fine-tuning-trainer-in-hugging-face-d295344d73f7