PyTorch Lightning
- 배경
- 구현하는 코드의 양이 늘어나며 코드의 복잡성이 증가하고 이에 따라 다양한 요소들이 복잡하게 얽히게 됨
- 이런 요소들은 서로 강하게 관계성을 가지게 되어, 한 부분을 변경하면 다른 부분에도 영향 - PyTorch에 대한 high-level 인터페이스를 제공하는 오픈소스 라이브러리
- High-level 인터페이스 : 복잡한 시스템이나 프로그램을 사용자에게 더 단순하고 이해하기 쉽게 만들어주는 인터페이스를 의미
- 딥러닝 모델 구축의 코드 템플릿으로써 기능을 하여 코드를 작성할 때 좀 더 정돈되고 간결화된 코드를 작성
- 코드 템플릿 : 특정 프로그래밍 작업을 간소화하거나 반복적인 코드 작성을 줄이기 위해 미리 정의된 코드 블록이나 구조를 의미 - 코드의 추상화 및 하드웨어 호출 자동화
- 코드의 추상화 : 복잡한 로직을 간단한 인터페이스 뒤에 숨기는 것을 의미
- 프로그래머는 내부 로직에 대해 신경쓰지 않고, 필요한 기능을 쉽게 사용
- PyTorch는 model, optimizer, training loop 등을 전부 따로따로 구현
- LightningModule 클래스 안에 모든 것을 한 번에 구현
- 클래스 내부의 메서드명은 모두 PyTorch Lightning에서 요구하는 대로 똑같이 써야 하며, 그 목적에 맞게 코딩
- PyTorch Lightning에서는 학습에 필요한 하드웨어(CPU, GPU)를 자동으로 호출하여 사용해 - 다양한 콜백 함수와 로깅
- 다양한 내장(built-in) 콜백 함수를 지원하며, 이를 사용해 딥러닝 학습과 연관된 특정 기능을 편리하게 사용
- 다양한 로깅 도구를 지원하여 로깅해야하는 값을 편리하게 기록하고 TensorBoard, WandB 등 모니터링 툴을 쉽게 사용 - 16-bit precision
- 일반적으로 딥러닝 모델에서 실수를 표현하는 비트 수는 32-bit인데, 이를 줄여 모델의 계산 속도 향상과 메모리 사용량을 줄이고자 한 아이디어가 16-bit precision
- 최근 딥러닝 연구에 사용되는 모델의 크기는 대체로 큰 경향성
- 이럴 경우, 모델 전체를 GPU에 로드하여 학습하고 사용하기에 제한
-> PyTorch Lightning에서는 16-bit precision과 같은 복잡한 기능 또한 옵션으로 추가하여 편리하게 사용
LightningModule
- LightningModule 클래스를 상속받아 모델의 구조, 손실 함수, 학습 및 평가 방법과 최적화 알고리즘을 클래스에 선언
- 모델 구조와 학습 로직을 함께 클래스로 선언하여, 코드의 구조를 더욱 명확하게 만들고 코드의 재사용성을 향상
- 구성
- __init__: 초기화를 담당하는 메서드로 모델의 레이어를 초기화합니다. 또한 학습 및 평가 과정에서 사용되는 손실 함수, 메트릭을 선언
- forward: 모델을 통해 데이터가 연산 되는 과정을 정의
- configure_optimizers: 최적화 알고리즘과 학습률 스케줄러를 정의하고 반환
- 반환할 때는 `return [optimizer], [scheduler]` 형식과 같이 순서를 맞춰야함 : 학습률 스케줄러는 생략가능
- training_step: 학습 데이터 셋의 미니 배치에 대해 손실을 반환하는 과정을 정의
- validation_step: validation set의 미니 배치에 대한 모델의 성능(손실, 메트릭)을 확인하는 과정을 정의
- test_step: test set의 미니 배치에 대한 모델의 성능(손실, 메트릭)을 확인하는 과정을 정의
- predict_step: 추론해야 하는 데이터 셋의 미니 배치에 대한 예측 과정을 정의
Trainer
- LightningModule의 메서드를 이용해 모델 학습을 실행하는 클래스
- 콜백 함수를 적절한 시점에 호출하거나 로깅 도구를 통해 학습 과정을 기록하는 과정도 자동으로 관리
- 메서드
- .fit(): LightningModule(model)과 train_dataloader, val_dataloader를 인자로 받아, 학습을 진행
- .validate(): LightningModule과 val_dataloader를 인자로 받아 validation set에 대한 평가를 진행
- .test(): LightningModule과 test_dataloader를 인자로 받아 test set에 대한 평가를 진행
- .predict(): LightningModule과 추론해야 하는 dataloader를 인자로 받아 모델의 결괏값을 반환
LightningModule 실습
- 설치 : https://lightning.ai/docs/pytorch/latest/starter/installation.html
- 로컬에 설치하는 경우, Window OS의 명령 프롬프트나 Mac OS, Linux OS의 터미널에서 `pip install lightning`을 통해 설치
- Colab 환경에서는 셀에 `!pip install lightning`으로 설치가능
- PyTorch로 CIFAR-10 데이터 셋을 사용해 분류하는 코드를 PyTorch Lightning으로 변환
- __init__
- 초기화 메서드에서는 PyTorch의 `nn.Module`의 __init__에서 선언한 레이어를 동일하게 작성하고 새로운 인자(e.g. 메트릭, 손실 함수)를 선언 - forward
- PyTorch의 forward 메서드는 동일하게 작성 - configure_optimizers
- 딥러닝 모델 학습에 사용될 최적화 알고리즘과 학습률 스케줄러를 작성
- 학습률 스케줄러는 생략가능 - training_step
- 미니 배치를 받아 학습 연산 과정을 거쳐 손실을 반환하는 코드를 작성 - validation_step
- 미니 배치를 받아 추론 과정을 거친 후, 로그를 기록하는 코드를 작성 - test_step
- `validation_step`과 마찬가지로 미니 배치를 받아 추론 과정을 거친 후, 로그를 기록하는 코드를 작성
- PyTorch 실습 코드는 예측 결과 및 라벨을 리스트에 추가하여 한 번에 메트릭을 계산했으나, PyTorch Lightning에서는 각 미니 배치의 메트릭을 계산해, 미니 배치수만큼의 메트릭을 평균 - predict_step
- 추론해야 하는 데이터 셋의 미니 배치에 대해 모델의 결괏값을 반환하는 코드를 작성
Trainer 실습
- 학습 과정에서 조기 종료(EarlyStopping)와 CSV 로깅 기능(CSVLogger)을 사용
- EarlyStopping : https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html
- CSVLogger : https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.loggers.csv_logs.html#module-lightning.pytorch.loggers.csv_logs
- `CSVLogger`를 사용할 경우, 학습과 검증 과정에서 저장된 로그를 comma-separated values(CSV) 파일로 저장
- PyTorch Lightning의 Logger를 통해서 다양한 모니터링 툴을 편리하게 사용
- LightningModule에서 기록한 로그를 TensorBoard와 WandB를 사용하여 모니터링
- TensorBoard는 `TensorBoardLogger`, WandB는 `WandbLogger`를 사용
- TensorBoardLogger : https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.loggers.tensorboard.html#module-lightning.pytorch.loggers.tensorboard
- WandbLogger : https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.loggers.wandb.html#module-lightning.pytorch.loggers.wandb
'Study > 머신러닝' 카테고리의 다른 글
PyTorch Hydra (0) | 2024.12.17 |
---|---|
PyTorch 전이학습 (2) | 2024.12.17 |
딥러닝과 PyTorch (0) | 2024.12.17 |
텐서 조작, Tensor Manipulation(with PyTorch) (0) | 2024.12.16 |
Pytorch (0) | 2024.12.16 |