본 문서에서는 IJCAI 2019에서 발표된 Belief Propagation Network for Hard Inductive Semi-supervised Learning 논문을 소개합니다. 논문에 대한 상세한 정보는 다음과 같습니다.
- Title: Belief Propagation Network for Hard Inductive Semi-supervised Learning
- Authors: Jaemin Yoo, Hyunsik Jeon, U Kang
- Conference: The 28th International Joint Conference on Artificial Intelligence (IJCAI 2019)
Graph-Based Semi-Supervised Learning
많은 기계학습 문제는 지도 학습(supervised learning) 상황을 대상으로 합니다. 즉, 라벨(label)이 주어진 훈련 데이터가 충분히 주어져 있다고 가정합니다. 반면, 반지도 학습(semi-supervised learning) 문제는 훈련 데이터의 양은 충분하되 일부 샘플에 대해서만 라벨이 주어져 있다고 가정합니다. 예를 들면 피처 벡터의 개수는 1,000개이지만 라벨의 개수는 10개밖에 안 되는 상황이 그렇습니다. 이럴 때 10개의 데이터 쌍만 사용해서 모델을 제대로 훈련하기는 어렵기 때문에, 라벨이 주어진 10개의 데이터와 라벨이 없는 990개의 데이터 사이의 관계를 찾아 전체 데이터를 훈련에 사용하는 것이 중요합니다.
반지도 학습 문제는 각 샘플을 정점(node)으로 하는 그래프(graph)가 주어져 있다면 훨씬 수월해집니다. 이 그래프를 통해 샘플 사이의 관계를 쉽게 파악할 수 있기 때문입니다. 예를 들어 페이스북의 친구 관계 그래프가 주어져 있다면 사람들 사이의 유사도를 비교적 쉽게 파악할 수 있고, 이를 이용해 각 사용자의 속성을 분류하는 반지도 학습 문제를 해결할 수 있습니다. 이렇게 그래프 데이터가 주어진 상황에서 반지도 학습을 푸는 상황을 그래프 기반 반지도 학습(graph-based semi-supervised learning)이라고 합니다. 최근 그래프 신경망 기술의 발전으로 그래프 기반 반지도 학습 문제가 많은 연구자의 관심을 받아 왔습니다.
Hard Inductive Learning
그래프 기반 반지도 학습 문제는 훈련 및 테스트 상황을 어떻게 가정하느냐에 따라 여러 가지 분류로 나눌 수 있습니다. 본 논문은 그중에서 가장 강한 제약 조건을 갖는 hard inductive learning (HIL) 문제를 대상으로 합니다 (이 명칭은 본 논문에서 처음 제시하였습니다). HIL 문제는 다음과 같은 조건을 가정합니다.
- 훈련 상황: 일반적인 그래프 기반 반지도 학습에서와 같이 그래프와 전체 정점의 피처 벡터가 함께 주어집니다. 따라서 GCN(graph convolutional networks)이나 GAT(graph attention networks) 등 반지도 학습을 위한 기존의 모델을 그대로 사용할 수 있습니다.
- 테스트 상황: 기존 문제와는 달리 별개의 그래프 없이 개별 피처 벡터만 주어진다고 가정합니다. 즉, 연결 정보가 전혀 없는 개별 벡터가 한 번에 하나씩 입력되는 온라인(online) 상황입니다. 이때 주어지는 샘플 벡터는 훈련 상황에서는 관측할 수 없습니다. 따라서, 관계에 대한 정보를 전혀 사용하지 않고 새로운 샘플에 대한 예측을 수행할 수 있어야 합니다.
이러한 조건은 새로운 샘플에 대한 즉각적인 예측 결과를 출력해야 하는 상황을 가정합니다. 예를 들어, 새로운 사용자가 페이스북 그래프에 추가되었을 때는 해당 사용자의 친구 정보가 충분하지 않기 때문에 연결 관계를 중시하는 예측 모델이 좋은 성능을 내기 어렵습니다. 반면, HIL 문제를 가정하고 훈련된 모델은 테스트 상황에서 정점의 연결 관계가 주어지지 않는다고 가정하기 때문에 더 좋은 성능을 낼 수 있습니다.
기존의 그래프 신경망(graph neural networks) 모델은 샘플간 그래프가 훈련 상황 및 테스트 상황 모두에서 주어진다고 가정하고 정점간 연결 관계를 핵심적인 예측 근거로 사용합니다. HIL 문제를 잘 해결하기 위해서는 훈련 상황에서 주어진 그래프를 반지도 학습을 위해 최대한 사용하되, 실제로 모델이 활용되는 예측 상황에서는 각 정점의 피처 벡터만을 사용하여 정확한 결과를 출력할 수 있어야 합니다.
Belief Propagation Network
본 논문에서는 새로운 그래프 신경망 모델인 BPN(Belief Propagation Network)을 제안합니다. BPN 모델은 이름에서 알 수 있듯 신뢰 전파(belief propagation) 알고리즘을 반지도 학습의 핵심 아이디어로 사용합니다. 동일한 형태의 레이어를 여러 번 쌓아 그래프 정보를 사용하는 기존 그래프 신경망과 달리, BPN은 각 정점의 예측값을 생성하는 단계와 각 정점의 예측값을 서로 관련짓는 단계를 구분하여 사용합니다.
BPN 모델은 먼저 전방 전달 신경망(multilayer perceptron)을 사용해 각 정점의 피처 벡터를 기반으로 예측값을 생성합니다. 이 신경망은 독립된 모델로서 그래프 정보를 사용하지 않고 클래스 예측을 수행합니다. 그 다음, 신뢰 전파 알고리즘을 이용해 예측값을 그래프 전체로 전파하고 인접한 정점이 비슷한 예측값을 갖도록 합니다. 이는 주어진 그래프를 확률 모델로 취급하여 신경망이 생성한 예측값을 각 변수의 사전 분포로 가정한 뒤 모든 정점의 사후 확률을 계산함으로써 이루어집니다. 그 결과, 그래프 구조가 신뢰 전파 알고리즘을 통해 예측 과정에 반영되어 반지도 학습을 할 수 있게 됩니다. 이 전체 과정은 미분 가능한 딥 러닝 프레임워크로 구현되어 역전파(backpropagation)를 통한 학습이 가능하도록 합니다.
그러나, HIL 문제를 잘 해결하기 위해서는 추가적인 아이디어가 필요합니다. 테스트 상황에서 그래프가 주어지지 않으면 신뢰 전파 알고리즘을 사용할 수 없으므로 그에 따라 정확도가 하락하기 때문입니다. 그래서 본 논문에서는 분류를 위한 교차 엔트로피(cross entropy) 함수에 더해서 새로운 induction loss를 사용합니다. 이 induction loss는 각 정점에 대한 신경망의 예측값이 신뢰 전파 알고리즘을 수행한 결과와 동일해지도록 모델을 학습합니다. 즉, 신뢰 전파 알고리즘의 결과를 새로운 정답으로 가정하여 신경망을 학습합니다. 그 결과, 신경망은 테스트 시점에서 그래프가 주어지지 않아도 정확한 예측 결과를 생성할 수 있게 됩니다.
Experiments
본 논문에서는 정점 분류 문제를 위한 4개의 그래프 데이터에 대해 실험 결과를 제시합니다. 그 결과, 아래의 표에서 보이는 것처럼 BPN 모델이 기존 모델의 성능을 크게 앞서는 결과를 보였습니다. 이는 HIL 상황을 가정하여 예측 단계와 신뢰 전파 단계를 구분하였기 때문에, 테스트 상황에서 그래프가 주어지지 않더라도 예측 모델이 각 정점에 대한 정확한 예측을 수행할 수 있기 때문입니다. 반면, GCN, GAT 등 기존의 그래프 신경망 모델은 HIL 상황에서 성능이 크게 하락하는 결과를 보였습니다.
추가적으로, 각 정점을 예측하기 위한 신경망과 그래프 구조를 이용하는 신뢰 전파 알고리즘을 분리하는 특성상 BPN 모델은 결과 해석에 큰 강점을 갖습니다. 각 테스트 정점에 대한 예측이 그래프 정보를 사용하지 않고 독립적으로 이루어지기 때문입니다. 아래 그림은 Amazon 데이터에서 각 상품의 카테고리를 분류할 때, Cell Phones, Electronics 등 각 클래스를 예측하기 위해 중요하게 사용하는 피처를 탐색한 결과입니다. 빨간색으로 체크되어 있는 피처는 훈련 데이터 내의 라벨된 노드에 전혀 등장하지 않았음에도 불구하고 중요한 피처로서 활용되었습니다. 이는 BPN 모델이 반지도 학습을 잘 수행한다는 근거입니다.
Conclusion
본 문서에서는 IJCAI 2019에서 발표된 Belief Propagation Network for Hard Inductive Semi-supervised Learning 논문을 소개하였습니다. 해당 논문은 Hard inductive learning을 위한 새로운 그래프 신경망 모델인 BPN(Belief Propagation Network)을 제안하였습니다. 또한, 실험 결과를 통해서 BPN이 기존 그래프 신경망보다 더 정확한 예측을 한다는 것을 보였습니다. 최근 그래프 신경망에 대한 연구가 활발히 이루어지고 있지만, 대부분의 연구는 테스트 상황에서 질 좋은 그래프가 주어져 정점간 관계를 충분히 활용할 수 있다고 가정합니다. 반면, 본 연구는 고립된 테스트 정점에 대한 분류 정확도를 최대화하는 것을 목표로 합니다.
예를 들어, 페이스북에 새로운 사용자가 유입되었다고 할 때, 기존에 학습된 모델을 이 사용자에 적용한다면 충분한 친구 관계의 부재로 좋은 성능을 내기 어렵습니다. 반면, 본 논문에서 제안한 BPN 모델은 그래프 구조를 바탕으로 개별 정점의 피처-라벨간 관계를 학습하기 때문에, 새로 유입된 사용자나 다른 사용자와 거의 관계를 맺지 않는 사용자에 대해서도 좋은 성능을 보입니다. 이는 온라인 쇼핑몰에서 새롭게 등록된 상품을 구매자에게 추천하는 경우에도 동일하게 적용됩니다. 따라서, 본 연구는 그래프 구조가 불안정한 상황에서 기존의 그래프 신경망을 보완하고 개선할 수 있습니다. 자세한 내용은 해당
링크에서 확인하실 수 있습니다.