https://github.com/chao1224/GraphMVP
https://openreview.net/pdf?id=xQUe1pOKPam
Github이 비어 있어 하고 있는 연구로의 적용은 어려울 것 같다는 생각이 들지만, 3D 구조 적용의 중요성에 대하여 재고해볼 수 있었다.
이 논문을 맨 처음 읽었을 때는 코드가 공개되어 있지 않았는데 최근 업데이트가 된 것 같다. (V. 220415)
regression 과 classification task 모두 적용 가능하다.
Prior knowledge
2D topology & 3D geometric
두 개념 모두 그래프가 전체적으로 이루는 모양과 구조적 정보를 의미한다. 최근 분자 성질 예측에서 3D geometry 정보가 더 중요하다는 사실 밝혀졌다. 이에 본 논문에서는 3D 구조를 반영할 수 있는 Graph Multi-viwe Pre-training 방식 제안하고자 하였다.
그렇다면 2D topology와 3D geometric의 차이는 무엇일까? 그림과 식으로 확인해보자.
✅ 2D Topology
Topology는 한국어로는 '위상' 으로 표현된다. 말 그대로 분자를 이루는 원자 하나 하나가 어떤 식으로 연결되어 있는지, 어떤 구조로 연결되어 있는지 등을 이루는 개념이다.
2D로 표현할 경우, 3D에 비해서는 입체적인 표현이 어렵다. 따라서 입체적인 구조 정보를 담기에는 한계가 있다. 따라서 원자들간의 연결 여부, 즉 인접성을 활용하여 2D 표현식의 구조 정보를 반영하고자 한다.
✅ graph는 node와 edge로 이루어져 있으며 분자를 graph로 나타냈을 때, node는 원자 하나를 의미한다.
✅ adjacency matrix: 그래프에서 어느 꼭짓점들이 변으로 연결되었는지 나타내는 정사각 행렬
구조 정보를 어떻게 반영할 수 있을까? 본 연구에서는 2D와 3D 표현식에 각각 2D transformation, 3D transformation을 적용해준 후 GNN을 모델에 fitting 시켜 $h_{x}$와 $h_{y}$ 라는 representation을 얻었다. 이를 input으로 pretrain model을 구축한다.
이를 표현하는 식은 다음과 같다.
$$h_{2D} = GNN-2D(T_{2D}(_{G2D}))= GNN - 2D(T_{2D}(X,E))$$
$G2D$ : 2D 분자 그래프 $(X,E)$
$T_{2D}$ : 2D transformer function
$h_{2D}$ : transformation 결과 반환되는 결과. 2D 분자 그래프를 벡터로 표현한 결과
$GNN-2D$ : Graph Isomorphism Network (GIN) model 활용, output- feature vector
GIN에 대한 내용은 추후 따로 포스팅하도록 하겠다.
✅ 3D geometric
geometry은 2D의 단순 노드끼리의 연결관계를 표현하는 것을 넘어, 3D 차원에서 Point, Polyline, Polygon 등의 요소를 포함한다. 이로써 좀 더 폭넓은 구조적 정보를 담을 수 있다. 3D geometric을 설명하기 위해 나오는 개념이 2가지 있는데, spatial position과 energy surface다.
spatial position은 2개 이상의 geometry가 공간상에서 가지게되는 위치를 의미한다. 다양한 공간적 관계에서 가지는 위치를 의미하며, 각 atom이 3D 공간에서 어떤 관계를 가지는지 (포함, 연결 등)를 나타낼 수 있다. (https://en.wikipedia.org/wiki/Spatial_relation)
energy surface는 화학에서 쓰이는 개념으로 , 분자의 에너지와 분자 기하학 정보 사이의 관계를 의미한다. 분자가 어떤 구조를 가지느냐에 따라 가지는 에너지가 달라진다.
https://link.springer.com/chapter/10.1007/978-3-319-30916-3_2
본 논문에서는 3D Molecular Graph는 추가적으로 원자의 spatial position을 포함한다고 한다. 이러한 spatial position은 원자들이 potential energy surface 위에서 연속적으로 움직이는데, 여기서 추가적으로 제시되는 개념이 conformer이다. confomer는 energy surface에서 지속적으로 움직이는 원자들이 local minimun에 위치한 3D 구조를 의미한다. 이 3D 구조는 분자의 중요한 특성으로 사용된다.
$$h_{3D} = GNN-3D(T_{3D}(_{G3D}))= GNN - 3D(T_{3D}(X,R))$$
$G3D$ : 3D 분자 그래프 $(X,R)$ / $R$ : 3D-coordinate matrix
$T_{3D}$ : 3D transformer function
$h_{3D}$ : transformation 결과 반환되는 결과. 3D 분자 그래프를 벡터로 표현한 결과
$GNN-3D$ : SchNet model 활용, output- feature vecto
Architecture of model
💡 main idea
pre-train 과정에서 2D view 데이터 뿐만 아니라 3D view 데이터를 활용하여 구조 정보를 더 잘 학습할 수 있도록 함
→ fine-tuning 과정에서는 2D view data만을 활용하면서 3D 정보 활용 가능하도록 함
일반적인 GNN 모델에서 input으로 들어가는 graph는 vector로 표현되어 있다. 본 논문에서는 vector가 아닌 graph 그 자체를 input으로 넣어 3D 정보를 반영하려 했다는 데에서 의의가 있다.
모델의 Main idea는 pretrained model 생성 시에 3D view 데이터를 사용하여 구조 정보를 잘 학습하도록 하는 것에 있다. high level 구조 정보도 잘 캐치하면서, 실제 downstream task에서는 3D 데이터를 배제하여 수행 속도를 높였다.
모델 아키텍처는 다음과 같다.
📚 GraphMVP
2D 데이터와 3D 데이터를 각각 2D transformer, 3D transformer 활용하여 벡터화하고 이를 $h_{x}$, $h_{y}$ 로 정의한다. 그 후 , 2개의 pretext task를 통해 3D view 정보(energy, conformer)를 잘 반영할 수 있는 pre-train model을 생성한다.
📚 pretext task
1) contrastive learining
constrastive learning에서는 분자를 표현하는 단계에서의 loss를 계산한다. 분자를 어떤 식으로 표현했을 때 best일지를 학습하는 것이다.
$h_{x}$와 $h_{y}$ 를 짝 지어 pair data를 생성한다. 이 pair data가 각각 같은 분자를 나타내는 벡터라면 (+), 다른 분자를 나타내는 벡터라면 (-)를 부여한다. 그 후 , (+)에 더 많은 weight를 할당하는 softmax 분류기 생성한다.
$L_{InfoNCE}$는 n개의 방식으로 계산한 sofrmax에 대한 cross entropy loss를 의미한다. InfoNCE는 Nosie constrastive estimation을 의미한다.
$f_{x}(x,y)$ =$f_{y}(x,y)$ = $exp(<h_{x}, h_{y}>)$ 는 score function을 의미한다.
계산 방식은 3D view sample을 하나 고정 (achor point로 지정) 후 2D view랑 비교하여 softmax / 2D view sample 하나 고정 (achor point) 후 3D view랑 비교하여 softmax 하는 식으로 계산된다.
2) generative task
generative task는 output을 도출하는 단계에서의 loss를 계산한다. 각 data 자체를 재구성하여 효과적으로 분자 구조를 잡아낸다. 즉, 2D, 3D 각각 공간정보를 재정의해서 나타내고 잘 나타냈는지 평가하는 것이다.
이 단계에서는 generative model 중 하나인 variational auto encoder(VAE)를 변형하여 사용한다. 앞의 constrastive task와 유사하게 2D topoly에 대응하는 3D conformer를 생성하고 이 생성된 정보가 3D 구조를 잘 나타내고 있는지 계산한다.
이를 위해 본 논문에서는 variational representation reconstruction(VRR) 방식을 제안한다. loss 계산식은 다음과 같다.
SG는 stop-gradient를 의미하며, KL은 KL-Divergence이다. KL-Divergence (KLD)는 P분포와 Q분포가 얼마나 다른지 측정하는 기법이다. 값이 낮을수록 두 분포 유사하다고 할 수 있다.
3) Objective function
2개의 pretext task에 가중치를 각각 달리하여 최종 모델 평가를 위한 GraphMVP-G와 GraphMVP-C를 계산한다.
$L_{GraphMVP-G}$ = $L_{GraphMVP-G$ + $alpha_{3}$ * $L{Generative 2D-SSL}$
$L_{GraphMVP-C}$ = $L_{GraphMVP-G$ + $alpha_{3}$ * $L{Constrastive 2D-SSL}$
Experiment
실험에 사용된 데이터는 다음과 같다.
[pre-train data]
- 50k qualified molecules from GEOM (2D, 3D stsructure 모두 존재하는 데이터)
[Downstream data]
moleculenet에 있는 benchmark dataset에 대해 실험하였으며, classification과 regression task에 대하여 모두 실험을 적용하였다
Result
regression (RMSE)
classification (AUC)
기존 SOTA model들보다 대부분은 좋은 성능을 보였다.
GraphMVP (Graph Multi-viwe Pre-training) 논문의 가장 큰 핵심은 3D 구조 정보를 활용하고자 시도했다는 점이다. 3D 정보를 담기 위해 다양한 시도들이 오가고 있는 것 같은데 3D 구조에서만 확인할 수 있는 conformation 요소가 분자 특성을 예측하는데에 얼마나 큰 기여를 할 수 있을지 궁금하다. 대부분의 task에 대해 좋은 성능을 보였음을 해당 논문에서 확인할 수 있었지만 3D 정보를 담기 위한 다른 방법은 없을지 생각해보게 된다.