Study/머신러닝

PyTorch Lightning

김 도경 2024. 12. 17. 15:30
PyTorch Lightning

 

  • 배경
    - 구현하는 코드의 양이 늘어나며 코드의 복잡성이 증가하고 이에 따라 다양한 요소들이 복잡하게 얽히게 됨
    - 이런 요소들은 서로 강하게 관계성을 가지게 되어, 한 부분을 변경하면 다른 부분에도 영향

  • PyTorch에 대한 high-level 인터페이스를 제공하는 오픈소스 라이브러리
    - High-level 인터페이스 : 복잡한 시스템이나 프로그램을 사용자에게 더 단순하고 이해하기 쉽게 만들어주는 인터페이스를 의미
    - 딥러닝 모델 구축의 코드 템플릿으로써 기능을 하여 코드를 작성할 때 좀 더 정돈되고 간결화된 코드를 작성
    - 코드 템플릿 : 특정 프로그래밍 작업을 간소화하거나 반복적인 코드 작성을 줄이기 위해 미리 정의된 코드 블록이나 구조를 의미

  • 코드의 추상화 및 하드웨어 호출 자동화
        - 코드의 추상화 : 복잡한 로직을 간단한 인터페이스 뒤에 숨기는 것을 의미
        - 프로그래머는 내부 로직에 대해 신경쓰지 않고, 필요한 기능을 쉽게 사용
    - PyTorch는 model, optimizer, training loop 등을 전부 따로따로 구현
    - LightningModule 클래스 안에 모든 것을 한 번에 구현
    - 클래스 내부의 메서드명은 모두 PyTorch Lightning에서 요구하는 대로 똑같이 써야 하며, 그 목적에 맞게 코딩
    - PyTorch Lightning에서는 학습에 필요한 하드웨어(CPU, GPU)를 자동으로 호출하여 사용해

  • 다양한 콜백 함수와 로깅
    - 다양한 내장(built-in) 콜백 함수를 지원하며, 이를 사용해 딥러닝 학습과 연관된 특정 기능을 편리하게 사용
    - 다양한 로깅 도구를 지원하여 로깅해야하는 값을 편리하게 기록하고 TensorBoard, WandB 등 모니터링 툴을 쉽게 사용

  • 16-bit precision
    - 일반적으로 딥러닝 모델에서 실수를 표현하는 비트 수는 32-bit인데, 이를 줄여 모델의 계산 속도 향상과 메모리 사용량을 줄이고자 한 아이디어가 16-bit precision
    - 최근 딥러닝 연구에 사용되는 모델의 크기는 대체로 큰 경향성
    - 이럴 경우, 모델 전체를 GPU에 로드하여 학습하고 사용하기에 제한
    -> PyTorch Lightning에서는 16-bit precision과 같은 복잡한 기능 또한 옵션으로 추가하여 편리하게 사용

LightningModule

- LightningModule 클래스를 상속받아 모델의 구조, 손실 함수, 학습 및 평가 방법과 최적화 알고리즘을 클래스에 선언
- 모델 구조와 학습 로직을 함께 클래스로 선언하여, 코드의 구조를 더욱 명확하게 만들고 코드의 재사용성을 향상

 

  • 구성
    - __init__: 초기화를 담당하는 메서드로 모델의 레이어를 초기화합니다. 또한 학습 및 평가 과정에서 사용되는 손실 함수, 메트릭을 선언
    - forward: 모델을 통해 데이터가 연산 되는 과정을 정의
    - configure_optimizers: 최적화 알고리즘과 학습률 스케줄러를 정의하고 반환
         - 반환할 때는 `return [optimizer], [scheduler]` 형식과 같이 순서를 맞춰야함 : 학습률 스케줄러는 생략가능
    - training_step: 학습 데이터 셋의 미니 배치에 대해 손실을 반환하는 과정을 정의
    - validation_step: validation set의 미니 배치에 대한 모델의 성능(손실, 메트릭)을 확인하는 과정을 정의
    - test_step: test set의 미니 배치에 대한 모델의 성능(손실, 메트릭)을 확인하는 과정을 정의
    - predict_step: 추론해야 하는 데이터 셋의 미니 배치에 대한 예측 과정을 정의
Trainer

- LightningModule의 메서드를 이용해 모델 학습을 실행하는 클래스

- 콜백 함수를 적절한 시점에 호출하거나 로깅 도구를 통해 학습 과정을 기록하는 과정도 자동으로 관리

  • 메서드
    - .fit(): LightningModule(model)과 train_dataloader, val_dataloader를 인자로 받아, 학습을 진행
    - .validate(): LightningModule과 val_dataloader를 인자로 받아 validation set에 대한 평가를 진행
    - .test(): LightningModule과 test_dataloader를 인자로 받아 test set에 대한 평가를 진행
    - .predict(): LightningModule과 추론해야 하는 dataloader를 인자로 받아 모델의 결괏값을 반환
LightningModule 실습

- 설치 : https://lightning.ai/docs/pytorch/latest/starter/installation.html
- 로컬에 설치하는 경우, Window OS의 명령 프롬프트나 Mac OS, Linux OS의 터미널에서 `pip install lightning`을 통해 설치
- Colab 환경에서는 셀에 `!pip install lightning`으로 설치가능

 

- PyTorch로 CIFAR-10 데이터 셋을 사용해 분류하는 코드를 PyTorch Lightning으로 변환

 

  • __init__
    - 초기화 메서드에서는 PyTorch의 `nn.Module`의 __init__에서 선언한 레이어를 동일하게 작성하고 새로운 인자(e.g. 메트릭, 손실 함수)를 선언

  • forward
    - PyTorch의 forward 메서드는 동일하게 작성

  • configure_optimizers
    - 딥러닝 모델 학습에 사용될 최적화 알고리즘과 학습률 스케줄러를 작성
    - 학습률 스케줄러는 생략가능

  • training_step
    - 미니 배치를 받아 학습 연산 과정을 거쳐 손실을 반환하는 코드를 작성

  • validation_step
    - 미니 배치를 받아 추론 과정을 거친 후, 로그를 기록하는 코드를 작성

  • test_step
    - `validation_step`과 마찬가지로 미니 배치를 받아 추론 과정을 거친 후, 로그를 기록하는 코드를 작성
    - PyTorch 실습 코드는 예측 결과 및 라벨을 리스트에 추가하여 한 번에 메트릭을 계산했으나, PyTorch Lightning에서는 각 미니 배치의 메트릭을 계산해, 미니 배치수만큼의 메트릭을 평균

  • predict_step
    - 추론해야 하는 데이터 셋의 미니 배치에 대해 모델의 결괏값을 반환하는 코드를 작성

 

Trainer 실습

- 학습 과정에서 조기 종료(EarlyStopping)와 CSV 로깅 기능(CSVLogger)을 사용
- EarlyStopping : https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html
- CSVLogger : https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.loggers.csv_logs.html#module-lightning.pytorch.loggers.csv_logs
      - `CSVLogger`를 사용할 경우, 학습과 검증 과정에서 저장된 로그를 comma-separated values(CSV) 파일로 저장

- PyTorch Lightning의 Logger를 통해서 다양한 모니터링 툴을 편리하게 사용
- LightningModule에서 기록한 로그를 TensorBoard와 WandB를 사용하여 모니터링
- TensorBoard는 `TensorBoardLogger`, WandB는 `WandbLogger`를 사용
    - TensorBoardLogger : https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.loggers.tensorboard.html#module-lightning.pytorch.loggers.tensorboard
    - WandbLogger : https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.loggers.wandb.html#module-lightning.pytorch.loggers.wandb

 

'Study > 머신러닝' 카테고리의 다른 글

PyTorch Hydra  (0) 2024.12.17
PyTorch 전이학습  (2) 2024.12.17
딥러닝과 PyTorch  (0) 2024.12.17
텐서 조작, Tensor Manipulation(with PyTorch)  (0) 2024.12.16
Pytorch  (0) 2024.12.16