Meta Learning - MAML
이번 포스팅에서는 Meta Learning 알고리즘 중 하나인 MAML을 공유해보려고 합니다.
개요
배경
인공신경망 기반의 모델들은 일반적으로 뛰어난 성능을 갖지만, 클래스 불균형 혹은 레이블 구분이 모호한 경우 기존 지도학습의 패러다임에서 벗어난 상황에서는 적용이 어렵습니다.
이러한 상황에 소량 데이터를 짧게 학습하여 최소한의 적응만으로 새로운 Task(여러 문제들)를 빠르게 배울수 있는 메타러닝 방법들이 활용되었습니다.
Few-shot Learning <-> Meta Learning
소량의 데이터로 학습하는 것을 Few-shot Learning 혹은 N-way K-shot Learning이라고 합니다. 이때 N은 클래스의 수, K는 클래스별 서포트(학습) 데이터 수를 의미합니다.Few-shot Learning에 Meta Learning을 도입하여 Classification 문제를 해결합니다.
Meta Learning이란, 모델 스스로가 학습 규칙을 도출할 수 있도록 구분하는 방법을 배우는 것을 말합니다(Learning to Learn). 접근 방식에 따라 여러 분류가 존재하지만, 본 포스팅에서는 Optimization-Based 기반의 Meta Learning의 알고리즘 중 하나인 MAML에 대해 설명합니다.
MAML (Model Agnostic Meta-Learning)
N-way K-shot 형태의 여러 Task의 샘플 데이터를 활용하여, 새로운 Task에도 빠르게 학습 가능한 Weight initial Point를 찾는 알고리즘입니다.
키워드
Model Agnostic
- 모델 종류에 상관없이 적용 가능합니다.
- Gradient Descent 방식을 사용하는 모든 모델에 적용 가능합니다.
Fast Adaptation
- 새로운 Task에 적은 Update로 빠르게 학습 가능합니다.
접근 방식
각 Task 1, 2, 3에 대해 코어 가중치 θ(검은 실선)가 각 Loss 로 인해 θ*1, 2, 3으로 변하는 것을 확인할 수 있습니다. 이는 Target Task 즉, 새로운 Task에 빠르게 최적화를 가능하게 합니다.
여러 Task에 대해서 적한한 모델을 학습하면 세부 Task에 적용할 때 Fine-Tuning 과정이 빠르게 될 것이다라는 아이디어로 접근한 것으로 보입니다.
Support set & Query set
Meta Learning에서는 데이터셋을 Support set과 Query set으로 부릅니다. Support set으로 학습을 진행하고 Query set으로 테스트를 진행합니다.
아래 그림과 같이, Meta-Learning에서는 Meta-Train D1~DN 개의 데이터로 Task를 샘플링해서 사용하며, 이때 각 Meta-Train Sampling Task set의 Support set으로 학습을 하며, Query set으로 각 Task를 추론 및 업데이트를 수행합니다.
학습
논문에서 제공하는 MAML 알고리즘입니다.
p(T)는 메타 학습 Task의 Task 분포, 전체 Task Set을 의미하며, 새로운 Task T_i에 있는 datapoints로 학습을 진행합니다. loss function를 최소화하는 방향으로 진행되며, 이는 기존 지도학습의 가중치 업데이트 방법과 같은 것으로 보입니다.
이때, 차이점은 학습 Task에서 몇개를 샘플링하고 각 Task에 대해서 역전파를 통해 각각의 θ_i를 구하고 이것을 전체 gradient에 θ를 업데이트합니다.
즉, 각 Task로부터 얻은 정보를 바탕으로 하나의 θ에 업데이트하며, θ를 시작점으로 몇 번의 혹은 적은 횟수의 gradient 업데이트를 통해 θ’을 찾는 것을 의미합니다.
테스트
Omniglot Dataset 사용 하여 MAML 모델의 Classification 결과를 확인합니다.
테스트 조건
Omniglot Dataset (1623개의 서로 다른 문자 데이터셋)
TensorFlow 1.15, Batch size 16, lr 1e-4
20-way 1-shot 적용합니다.
새로운 Task는 5 Gradient Step 으로 학습하여 결과를 확인합니다.
테스트 결과
논문에서는 정확도 95.8%(오차 0.3%)의 결과를 보였으며, 본 테스트 결과 93.38%의 결과를 도출했습니다.
이를 통해 새로운 Task 테스트 결과 initialization weight를 잘 수행한 것으로 보입니다.
정리 및 결론
일종의 배경지식을 모델에 훈련시켜 비슷한 종류의 새로운 Task 문제를 해결하는 MAML 알고리즘을 적용해보았습니다.
새로운 Task에 대해 빠르게 학습 가능한 가중치 초기점을 찾는 지와 예측 결과까지 확인해보았습니다.
93.38%는 분류 성능으로서는 높지 않지만, 학습 데이터가 부족한 경우에 효과적일 수 있으며, Model Agnostic 하다는 점에서 여러 활용이 가능할 것으로 보입니다.
참고 논문