화학물질 독성 및 질환 발생 예측 연구를 위해 분자 데이터를 다룬 다양한 연구 사례 중 GROVER 논문을 읽어보았다. 해당 논문은 일반적으로 SMILES 방식으로 분자 데이터를 나타내던 기존 방식에서 벗어나 graph 형식으로 데이터를 표현하여 모델을 pre-train 시킨다.
이를 위해 transformer의 인코더 부분을 사용하였으며, self-supervised 방식으로 학습하였다. 인접 행렬로 된 graph data를 GROVER 모델을 통해 node embed, edge embed로 출력한다.
https://drug.ai.tencent.com/publications/GROVER.pdf
https://github.com/tencent-ailab/grover
코드는 공부 중이다!
✏️ 제안 배경
1) 분자 데이터의 한계점
: graph로 표현 가능 (node, edge로 구성)
[분자 데이터에서 기존 딥러닝 모델들의 한계점]
- 구조의 복잡성으로 인해 일반화 어려움
- Labeled data 부족: 시간, 비용 high → Self-supervised learning model 제안
2) 분자 데이터의 기존 표현 방식
- SMILES: 화학 물질의 구조를 문자열로 나타낸 방식 (sequential data)
- 분자의 구조에 대한 정보는 명확하게 나타낼 수 없어 Topology property 파악 어려움
→ 분자를 SMILES 방식이 아닌 graph 형식으로 represent하여 pre-train model 제안
(GNN 모델의 확장)
✏️ 선행 연구
1) self- supervised learning
- 모델의 일반성을 높이기 위한 방안으로 제시. labeled data 부족 이슈도 해결 가능
- 직접 라벨링하면서 학습하는 방식
- pretext task: 사용자가 label이 없는 데이터를 labeling하기 위해 직접 정의하는 task
[ pipeline]
- pretext task 지정
- labeling되어 있지 않은 데이터를 사용하여 지정된 pretext task 수행
- ➡️ pre-train(미리 학습)한다고 함
- pre-train한 모델의 성능 확인*transfer learning : pre-trained model을 사용하는 것, 이 단계는 supervised
- : downstream task에 가져와 wight는 고정시킨 채로 transfer learning
- fine-tuning: pre-trained model 의 성능 향상→ 더 적은 레이블의 새로운 데이터에도 적용 가능하도록 함
2) Self - attention Model
- multi-head attention: attention layer를 평행으로 쌓아 동시에 running
- input
Q: 단어에 대한 가중치
K: 단어가 Query와 얼마나 연관되었는지 비교하는 가중치
V: 의미에 대한 가중치 - output
Context vector
3) GNN
- G = (V,E) V = set of node, E = set of edge, 인접 행렬로 표현
- Node, edge, graph level에서의 예측 등에 사용
- 이웃 node와의 연결 관계를 활용하여 node의 상태를 update하고, 마지막 상태를 Node embedding 으로 출
- message passing
- 각 node마다 이웃 node와 연결 관계 수집
- 모든 정보 aggregate function 이용해 집계
- Update function 통해 업데이트 후 출력
✏️ Model architecture
- Graph Transformer (w/ graph multi-head self- attention)
- 2개의 Gtransformer로 구성 (Node GTransformer + Edge GTransformer)
- DYMPN (DyNamic Messaege Passing Network) 사용
- 전체적인 구조
1) Input graph (인접 행렬)→ 벡터화된 k,q,v 추출
[Dynamic message passing (DyMPN)]
- 각 node마다 이웃 node와 연결 관계 수집
- 모든 정보 aggregate function 이용해 집계
- Update function 통해 업데이트 후 출력
=> 업데이트 과정에서 iteration마다 update 횟수를 random하게 실시 (MPN과의 차이점)→ more general model 만들 수 있음
2) K,q,v → attention value (norm.) → node, edge level에서 각각 aggregate → concat
- Aggreate: 내 자신에게 나 제외한 이웃의 정보 합해서 전달
- Concat: 나와 내 이웃 정보 모두 종합 (이웃 정보 바탕으로 나 자신 update 하는 것)
3) 각 위치마다 병렬적 처리 → norm → node embed, edge embed
✏️ pre-train model
1) pretext task
a) Contextual property prediction
하나의 graph 내에서 subgraph가 어떤 cluster에 속하는지 classification
- Target node의 일부를 subgraph로 지정
- k개의 이웃 node, edge 추출→ 통계적 특성 추출
- Subgraph들은 추출된 특성에 따라 clustered
b) Graph-level motif prediction
graph 차원에서 motif가 어떤 category에 속하는지 classification
→ multi label classification
2) Fine-tuning for downstream task
[downstream task 종류]
- node level (e.g. node classification)
- edge level (e.g. link prediction)
- graph level (e.g. property prediction for molecules)
[fine tuning]
- pre train 과정에서 labeling한 supervised data의 일부를 사용
- GROVER encoder, READOUT function, 사용할 모델 (ex. MLP) 등을 튜닝
- 성능 향상 가능
✏️ 실험
1) pre- train data
- unlabeled data
- ZINC15, Chembl data set에서 1,100만개 samlpling Chembl: 약물과 유사한 특성을 가진 생체 활성 분자 데이터 ZINC15: 가상 환경에서 사용할 수 있는 화학 화합물 모음. 화합물을 3D 포맷으로 표현하고 있음
- 10% random 하게 split하여 validation으로 사용
2) fine-tuning tasks & dataset
- MoarNet 연구의 데이터를 benchmark dataset으로 사용
- benchmark dataset: 여러 실험 또는 모델의 성능을 비교할 수 있는 표준 데이터셋
생물, 물리학, 물리 화학, 양자 역학 등 다양한 분야
Molecular Classification Datasets.
- BBBP : 화합물이 혈액-뇌 장벽을 관통하는 투과성 특성을 가지고 있는지 여부
- SIDER : 부작용 자원이라고도 알려진 약물 부작용과 함께 시판되는 약품들
- ClinTox: FDA를 통해 승인된 약물과 임상시험 중 독성 때문에 제거된 약물을 비교
- BACE: 지난 몇 년 동안 인간 β-시크릿레이스 1(BACE-1)의 억제제로 작용할 수 있는 생태화합물
- Tox21: 화합물의 독성을 측정하는 공공 데이터베이스
- ToxCast :수천 개의 화합물에 대한 여러 개의 독성 라벨을 부착한 결과
Molecular Regression Datasets.
- QM7: HOMO/LUMO와 같이 안정적이고 합성적으로 접근 가능한 유기 분자의 계산된 원자화 에너지를 기록한 데이터의 일부, 삼중 결합 등 다양한 분자 구조 포함
- QM8: 컴퓨터가 생성한 양자 역학적 특성.
- ESOL: 화합물의 용해성
- Lipophilicity: 옥탄올/물 분포 계수 실험을 통해 얻은 데이터로 분자막 투과성과 용해도에 영향을 미치는 중요한 성질
- FreeSolv: 실험과 연금술 자유 에너지 계산 양쪽에서 물에 작은 분자의 수화 자유 에너지
Data Splitting
- scaffold splitting→ train/ validation/ test= 8:1:1
- 구조적으로 다른 분자는 다른 subset으로 나눔→ 분자 데이터의 성징 예측시 많이 사용되는 splitting method
- 실험 세팅에 있어 좀 더 현실적인 모델 성능 예측 가능
- 각각 dataset에 대하여 독립적으로 실시
Baseline Model
- MoleculeNet: 분자 속성에 대한 기계 학습 방식들을 테스트하기 위해 다양한 실험 및 모델 고안
- MoleculeNet의 SOTA model과 비교
- TF_Roubust : DNN 기반의 mulitask framework, input-molecular fingerprint (화합문 구조의 특징을 인진법화한 벡터)
- GraphConv, Weave, SchNet: Graph Conventional model
- MPNN. DMPNN, MGCN: message passing 과정에서 edge feature를 중점적으로 고려하는 모델
- AttentiveFP: graph attention network의 확장
- GROVER: N-gram & 일반-> self supervised learning에서의 효과 검증
3) experiment configuration
Goal: 분자 특성 예측
1) GROVER pre-train (사전 지정했던 pretext task 수행)
- optimizer: Adam optimizer (Adaptive Moment Estimation): 각 파라미터마다 다른 크기와 방향의 업데이트 적용 learning rate: 0.000015, epochs= 500
- Contextual property prediction: 하나의 graph 내에서 subgraph가 어떤 label가지는가?
- context 반지름을 k=1로 설정
- 각각 다른 2518개의 node, 2686개의 edge의 속성 추출 후 이를 label로 지정
- node와 edge label의 15%를 random하게 가리고 이를 예측
- Graph-level motif prediction: graph 차원에서 motif가 어떤 category에 속하는지 classification
- RDKit 사용하여 85개의 motif 추출 (motif: 그래프를 이해할 수 있도록 해주는 특성, pattern, recurring, significant의 특징 )
- motif의 label은 one-hot vector
- model size test: hyperparameter 동일하게 두고 hidden layer의 크기만 다르게 지정
- environment: 250 Nvidia V100 GPU
- time
2.5 days | 4 days |
2) Fine-tuning
- 평가: validation lass
- 100 epochs
- random search hyper parameter tuning
- result
대부분의 dataset에서 일관되게 좋은 성능 가짐. 다른 모델들 보다 상대적으로 평균 6.1%정도 성능 좋았음.
→ 매우 적은 레이블 정보로 좋은 성능을 보이는 GROVER 모델의 장점
✏️ Conclusion and Future Works
Conclusion
- 좋은 pretext task의 지정과 데이터를 잘 표현할 수 있는 모델 구조의 필요성 재고
- 분자 특성 벤치마크에서 현재 최고 성능이라고 여겨지는 모델보다 평균 6% 이상의 성능 향상 달성 가능
Future Work
- GNN pre-train 과제에서 pretext task의 지정은 핵심. 논문에서 제안한 것 외에도 다른 의미있는 task를 지정헀을 시에 성능 향상 기대 가능 ( ex.distance-preserving task, tasks that getting 3D input )
- downstream task의 확장: 해당 paper에서는 분자의 특성을 예측. 하지만 더 넓은 범주에서의 downstream task 지정 시도 가능 (node 예측, link 예측)→ 이에 맞는 다른 pre- train 방식이나 self- supervision 방식 연구 가능
- Wider and deeper models: 더 복잡한 작업에 대해 더 많은 의미 정보 포착 가능