Example of 2D Graph Modality
Here we provide example code for exploring the neural scaling law of the 2D graph modality with a GIN backbone. You can easily compare this example with the other modalities (fingerprint, SMILES string and 3D graph), since they use a similar code framework.
[1]:
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import global_mean_pool
from torch_geometric.loader import DataLoader
from splitter import random_split
from datasets.molnet import MoleculeDataset
from model.gnn import GNN
from model.mlp import MLP
1. Define the basic functions we need to use.
[2]:
def seed_all(seed):
    """Seed every random number generator this notebook draws from.

    Seeds Python's ``random`` module (used to subsample the training set),
    NumPy's global RNG and PyTorch's CPU and CUDA generators.

    Args:
        seed: integer seed; any falsy value (e.g. ``None``) falls back to 0.
    """
    if not seed:
        seed = 0
    print("[ Using Seed : ", seed, " ]")
    torch.manual_seed(seed)
    # Seeds every visible GPU and is silently ignored without CUDA, so a
    # separate torch.cuda.manual_seed(seed) for the current device is redundant.
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    # Previously missing: Python's own RNG was left unseeded.
    random.seed(seed)
def get_num_task(dataset):
    """Return the output dimension (number of prediction targets) of a dataset.

    Args:
        dataset: lower-case MoleculeNet dataset name, e.g. ``'hiv'``.

    Raises:
        ValueError: if the dataset name is not one of the known datasets.
    """
    targets_per_dataset = {
        'tox21': 12,
        'hiv': 1,
        'bace': 1,
        'bbbp': 1,
        'muv': 17,
        'toxcast': 617,
        'sider': 27,
        'clintox': 2,
        'pcba': 92,
    }
    try:
        return targets_per_dataset[dataset]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which the original also rejected.
        raise ValueError('Invalid dataset name.') from None
# Global seed for this run; also reused when subsampling the training set.
seed = 0
# seed_all already seeds torch (CPU and all CUDA devices) and numpy, so no
# separate CUDA re-seed is needed here.
seed_all(seed)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
[ Using Seed : 0 ]
2. Load the dataset and partition it with a random split. You can easily replace it with other split methods (scaffold and imbalanced).
[3]:
# Which MoleculeNet task to run and how to split it.
dataset_name = 'hiv'
split = 'random'
num_tasks = get_num_task(dataset_name)

# Set your dataset directory. It must end with '/', because the paths below
# are built by plain string concatenation.
dataset_folder = '/home/YourDir/datasets/molecule_net/'
dataset_root = dataset_folder + dataset_name
dataset = MoleculeDataset(dataset_root, dataset=dataset_name)
print(dataset)
eval_metric = roc_auc_score

# Only the random split is wired up in this example.
if split != 'random':
    raise ValueError('Invalid split option.')

# SMILES strings come from the first column of a header-less CSV inside the
# processed dataset folder.
smiles_path = dataset_root + '/processed/smiles.csv'
smiles_list = pd.read_csv(smiles_path, header=None)[0].tolist()
# 80/10/10 split; the split seed (42) is fixed independently of the training seed.
split_result = random_split(
    dataset, null_value=0, frac_train=0.8, frac_valid=0.1,
    frac_test=0.1, seed=42, smiles_list=smiles_list)
train_dataset, valid_dataset, test_dataset, smiles_splits, _ = split_result
train_smiles, valid_smiles, test_smiles = smiles_splits
print('randomly split')

print(train_dataset[0])
print('Training data length: {}'.format(len(train_smiles)))
Dataset: hiv
Data: Data(x=[1049163, 2], edge_index=[2, 2259376], edge_attr=[2259376, 2], id=[41127], fingerprint=[41127, 1024], y=[41127])
MoleculeDataset(41127)
randomly split
Data(y=[1], x=[39, 2], edge_index=[2, 84], fingerprint=[1, 1024], edge_attr=[84, 2], id=[1])
Training data length: 32901
3. Randomly select a fraction of the training set (1%, 5%, 10%, …, 100%) for model training.
[4]:
# Fraction of the training split actually used for training (1%, 5%, ..., 100%).
finetune_ratio = 0.01
batch_size = 256

num_mols = len(train_dataset)
# Size of the training subset. It is derived from the same length that is
# sliced below, and kept at least 1 so a tiny ratio never produces an empty
# training loader.
finetune_num = max(1, int(finetune_ratio * num_mols))

# Reproducible random subset of training indices.
random.seed(seed)
all_idx = list(range(num_mols))
random.shuffle(all_idx)
ids = all_idx[:finetune_num]
# NOTE: train_dataset is overwritten with the subset. Re-running this cell
# without re-running the split cell subsamples the already-reduced set again.
train_dataset = train_dataset[ids]

train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          shuffle=True, num_workers=8)
val_loader = DataLoader(valid_dataset, batch_size=batch_size,
                        shuffle=False, num_workers=8)
test_loader = DataLoader(test_dataset, batch_size=batch_size,
                         shuffle=False, num_workers=8)
4. Set up the backbone model and optimizer.
[5]:
# Backbone hyper-parameters.
num_layer = 5
emb_dim = 300
dropout_ratio = 0.5
lr = 1e-3

# Graph encoder plus a prediction head that maps the pooled emb_dim-dimensional
# graph embedding to num_tasks logits.
model = GNN(num_layer=num_layer, emb_dim=emb_dim, drop_ratio=dropout_ratio).to(device)
output_layer = MLP(in_channels=emb_dim, hidden_channels=emb_dim,
                   out_channels=num_tasks, num_layers=1, dropout=0).to(device)

# Head and backbone are optimized jointly at the same learning rate.
model_param_group = [
    {'params': output_layer.parameters(), 'lr': lr},
    {'params': model.parameters(), 'lr': lr},
]
print(model)
optimizer = optim.Adam(model_param_group, lr=lr, weight_decay=0)
GNN(
(x_embedding1): Embedding(120, 300)
(x_embedding2): Embedding(3, 300)
(gnns): ModuleList(
(0): GINConv()
(1): GINConv()
(2): GINConv()
(3): GINConv()
(4): GINConv()
)
(batch_norms): ModuleList(
(0): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(1): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
5. Train and evaluate the model.
[6]:
# Element-wise BCE on logits (no reduction), so missing labels can be masked
# out before averaging.
criterion = nn.BCEWithLogitsLoss(reduction='none')

# Per-epoch histories appended to by the training loop.
train_roc_list = []
val_roc_list = []
test_roc_list = []
train_loss_list = []
val_loss_list = []
test_loss_list = []
roc_lists = []  # per-task test scores, one list per epoch

# Best validation score so far and the 0-based epoch index where it occurred.
best_val_roc = -1
best_val_idx = 0
# Lowest mean training loss seen so far; lowered inside train_general.
optimal_loss = 1e10
def train_general(model, device, loader, optimizer):
    """Run one training epoch and return the mean per-batch loss.

    Node embeddings from ``model`` are mean-pooled per graph and passed through
    the module-level ``output_layer`` head. The loss is the module-level
    ``criterion`` (element-wise BCE on logits), averaged over entries whose
    label is non-zero. Labels are mapped as ``(y + 1) / 2``, so +1/-1 become the
    1/0 targets and a label of 0 is treated as missing.

    Args:
        model: graph encoder, called as ``model(batch)``.
        device: device every batch is moved to.
        loader: iterable of PyG batches exposing ``.y`` and ``.batch``.
        optimizer: optimizer over the encoder and head parameters.

    Returns:
        Mean loss over the batches that contained at least one valid label.
        As a side effect, lowers the module-level ``optimal_loss`` when beaten.

    Raises:
        ValueError: if no batch in ``loader`` contained a valid label.
    """
    global optimal_loss
    model.train()
    # The head is trained too, so its dropout/normalisation must be in train mode.
    output_layer.train()
    total_loss = 0.0
    num_batches = 0
    for batch in loader:
        batch = batch.to(device)
        h = global_mean_pool(model(batch), batch.batch)
        pred = output_layer(h)
        y = batch.y.view(pred.shape).to(torch.float64)
        is_valid = y ** 2 > 0
        num_valid = is_valid.sum()
        if num_valid == 0:
            # Averaging over zero entries gives 0/0 = NaN, which would poison
            # the weights through backward(); skip such batches.
            continue
        loss_mat = criterion(pred.double(), (y + 1) / 2)
        loss_mat = torch.where(is_valid, loss_mat, torch.zeros_like(loss_mat))
        loss = loss_mat.sum() / num_valid
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.detach().item()
        num_batches += 1
    if num_batches == 0:
        raise ValueError('train_general: no batch with a valid label in loader.')
    mean_loss = total_loss / num_batches
    if mean_loss < optimal_loss:
        optimal_loss = mean_loss
    return mean_loss
def eval_general(model, device, loader):
    """Evaluate ``model`` plus the module-level ``output_layer`` head on ``loader``.

    Args:
        model: graph encoder, called as ``model(batch)``.
        device: device every batch is moved to.
        loader: iterable of PyG batches exposing ``.y`` and ``.batch``.

    Returns:
        ``(mean_score, mean_loss, score_list)``:
        - ``score_list`` holds the module-level ``eval_metric`` for every task
          that has both +1 and -1 labels.
        - ``mean_score`` is their average, or NaN if no task qualifies.
        - ``mean_loss`` is the masked BCE loss averaged over batches, computed
          the same way as the training loss. It was previously always 0.
    """
    model.eval()
    # Put the head in eval mode as well, then restore its previous mode so a
    # training loop that only calls model.train() is unaffected.
    head_was_training = output_layer.training
    output_layer.eval()
    y_true, y_scores = [], []
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            h = global_mean_pool(model(batch), batch.batch)
            pred = output_layer(h)
            true = batch.y.view(pred.shape)
            # Same masked BCE as in training; a label of 0 is treated as missing.
            y = true.to(torch.float64)
            is_valid = y ** 2 > 0
            loss_mat = criterion(pred.double(), (y + 1) / 2)
            loss_mat = torch.where(is_valid, loss_mat, torch.zeros_like(loss_mat))
            # clamp avoids 0/0 = NaN for a batch without any valid label.
            total_loss += (loss_mat.sum() / is_valid.sum().clamp(min=1)).item()
            num_batches += 1
            y_true.append(true)
            y_scores.append(pred)
    output_layer.train(head_was_training)

    y_true = torch.cat(y_true, dim=0).cpu().numpy()
    y_scores = torch.cat(y_scores, dim=0).cpu().numpy()
    roc_list = []
    for i in range(y_true.shape[1]):
        col_true = y_true[:, i]
        # AUC is only defined when both a positive and a negative label exist.
        if np.sum(col_true == 1) > 0 and np.sum(col_true == -1) > 0:
            col_valid = col_true ** 2 > 0
            roc_list.append(eval_metric((col_true[col_valid] + 1) / 2, y_scores[col_valid, i]))
        else:
            print('{} is invalid'.format(i))
    if len(roc_list) < y_true.shape[1]:
        print(len(roc_list))
        print('Some target is missing!')
        print('Missing ratio: %f' %(1 - float(len(roc_list)) / y_true.shape[1]))
    # NaN (instead of ZeroDivisionError) when no task could be scored; it never
    # compares greater than the best score, so model selection ignores it.
    mean_score = sum(roc_list) / len(roc_list) if roc_list else float('nan')
    return mean_score, total_loss / num_batches, roc_list
[7]:
# Number of passes over the (subsampled) training set.
epochs = 100
train_func = train_general
eval_func = eval_general
for epoch in range(1, epochs + 1):
    loss_acc = train_func(model, device, train_loader, optimizer)
    print('Epoch: {}\nLoss: {}'.format(epoch, loss_acc))

    # Training-set ROC-AUC is not evaluated (it would cost an extra pass), but
    # the epoch's mean training loss is recorded rather than discarded as 0.
    train_roc = 0
    train_loss = loss_acc
    val_roc, val_loss, _ = eval_func(model, device, val_loader)
    test_roc, test_loss, roc_list = eval_func(model, device, test_loader)

    train_roc_list.append(train_roc)
    val_roc_list.append(val_roc)
    test_roc_list.append(test_roc)
    train_loss_list.append(train_loss)
    val_loss_list.append(val_loss)
    test_loss_list.append(test_loss)
    roc_lists.append(roc_list)
    print('val: {:.6f}\ttest: {:.6f}'.format(val_roc, test_roc))
    print()

    # Model selection: remember the epoch with the best validation score,
    # stored 0-based so it indexes the *_list histories directly.
    if val_roc > best_val_roc:
        best_val_roc = val_roc
        best_val_idx = epoch - 1
Epoch: 1
Loss: 0.7890448979555018
val: 0.493066 test: 0.467586
Epoch: 2
Loss: 0.7206984094669178
val: 0.521066 test: 0.567369
Epoch: 3
Loss: 0.6808990557626126
val: 0.546758 test: 0.583905
Epoch: 4
Loss: 0.6745207162534314
val: 0.571512 test: 0.604082
Epoch: 5
Loss: 0.6650212782655167
val: 0.555260 test: 0.630019
Epoch: 6
Loss: 0.6609979158821531
val: 0.552788 test: 0.629470
Epoch: 7
Loss: 0.6365597884904656
val: 0.545063 test: 0.609288
Epoch: 8
Loss: 0.6325839394666748
val: 0.558695 test: 0.598746
Epoch: 9
Loss: 0.6194784707871628
val: 0.560854 test: 0.613101
Epoch: 10
Loss: 0.6119444207838427
val: 0.545393 test: 0.587247
Epoch: 11
Loss: 0.6084267317479832
val: 0.546163 test: 0.584681
Epoch: 12
Loss: 0.594379795679181
val: 0.550446 test: 0.580790
Epoch: 13
Loss: 0.5814032406936134
val: 0.558260 test: 0.591325
Epoch: 14
Loss: 0.5821246152011825
val: 0.572061 test: 0.609039
Epoch: 15
Loss: 0.5623563324240084
val: 0.572713 test: 0.608568
Epoch: 16
Loss: 0.5629411729916913
val: 0.561830 test: 0.589468
Epoch: 17
Loss: 0.5418217095759776
val: 0.564055 test: 0.575939
Epoch: 18
Loss: 0.5380657321434608
val: 0.590165 test: 0.533879
Epoch: 19
Loss: 0.5214687218438495
val: 0.612862 test: 0.511216
Epoch: 20
Loss: 0.5150101393634872
val: 0.624103 test: 0.540542
Epoch: 21
Loss: 0.4955445477030514
val: 0.634808 test: 0.578347
Epoch: 22
Loss: 0.48988493107218567
val: 0.632235 test: 0.596660
Epoch: 23
Loss: 0.4843246491776333
val: 0.634006 test: 0.589116
Epoch: 24
Loss: 0.46250356726144015
val: 0.610571 test: 0.546207
Epoch: 25
Loss: 0.44351343675443594
val: 0.585388 test: 0.517635
Epoch: 26
Loss: 0.4311040234292012
val: 0.584375 test: 0.506485
Epoch: 27
Loss: 0.42884370231803304
val: 0.583910 test: 0.502525
Epoch: 28
Loss: 0.41645678149041465
val: 0.593627 test: 0.503896
Epoch: 29
Loss: 0.3978502586092939
val: 0.617391 test: 0.506926
Epoch: 30
Loss: 0.3753017187324836
val: 0.635705 test: 0.532077
Epoch: 31
Loss: 0.37998483145674034
val: 0.628083 test: 0.533108
Epoch: 32
Loss: 0.3576763418170913
val: 0.589657 test: 0.512151
Epoch: 33
Loss: 0.33572181544907065
val: 0.559906 test: 0.526929
Epoch: 34
Loss: 0.3245714498859569
val: 0.563185 test: 0.564498
Epoch: 35
Loss: 0.3115902600749161
val: 0.571337 test: 0.583199
Epoch: 36
Loss: 0.30301906367055675
val: 0.564221 test: 0.549889
Epoch: 37
Loss: 0.29085193507140294
val: 0.552433 test: 0.515697
Epoch: 38
Loss: 0.280614545207421
val: 0.549898 test: 0.509258
Epoch: 39
Loss: 0.2638803713267179
val: 0.564261 test: 0.515690
Epoch: 40
Loss: 0.2784722773760675
val: 0.606568 test: 0.533551
Epoch: 41
Loss: 0.23600026165419447
val: 0.620526 test: 0.560151
Epoch: 42
Loss: 0.22487450013132906
val: 0.619791 test: 0.584050
Epoch: 43
Loss: 0.23739796695298396
val: 0.629150 test: 0.596498
Epoch: 44
Loss: 0.20377325739748045
val: 0.575724 test: 0.553914
Epoch: 45
Loss: 0.2042612950155936
val: 0.563912 test: 0.528640
Epoch: 46
Loss: 0.18829806694788767
val: 0.585577 test: 0.558088
Epoch: 47
Loss: 0.18532209452555687
val: 0.584719 test: 0.581878
Epoch: 48
Loss: 0.19661826269119664
val: 0.577035 test: 0.590247
Epoch: 49
Loss: 0.1777977850554482
val: 0.591134 test: 0.614577
Epoch: 50
Loss: 0.1688708345836757
val: 0.621861 test: 0.651868
Epoch: 51
Loss: 0.15549886980732563
val: 0.648588 test: 0.664025
Epoch: 52
Loss: 0.15489401141893733
val: 0.663913 test: 0.664687
Epoch: 53
Loss: 0.14611072828856853
val: 0.662711 test: 0.655157
Epoch: 54
Loss: 0.1414666344323933
val: 0.647173 test: 0.640932
Epoch: 55
Loss: 0.14527067568702584
val: 0.623229 test: 0.626670
Epoch: 56
Loss: 0.14593197044572598
val: 0.606861 test: 0.604167
Epoch: 57
Loss: 0.13939075120189384
val: 0.610515 test: 0.586252
Epoch: 58
Loss: 0.1251759732087336
val: 0.611797 test: 0.538693
Epoch: 59
Loss: 0.1343301963358481
val: 0.616589 test: 0.522969
Epoch: 60
Loss: 0.11001557274238116
val: 0.619285 test: 0.524458
Epoch: 61
Loss: 0.15277526986192563
val: 0.608661 test: 0.545176
Epoch: 62
Loss: 0.11538843484918088
val: 0.609680 test: 0.588619
Epoch: 63
Loss: 0.0985766438425654
val: 0.620455 test: 0.609299
Epoch: 64
Loss: 0.10735864472986267
val: 0.642322 test: 0.621622
Epoch: 65
Loss: 0.10676408858572192
val: 0.639177 test: 0.624693
Epoch: 66
Loss: 0.13325658329221335
val: 0.547236 test: 0.545921
Epoch: 67
Loss: 0.09812393359235637
val: 0.435594 test: 0.449668
Epoch: 68
Loss: 0.11815496698191
val: 0.440811 test: 0.456065
Epoch: 69
Loss: 0.13404194416895196
val: 0.489754 test: 0.533839
Epoch: 70
Loss: 0.09608796220740463
val: 0.537089 test: 0.593275
Epoch: 71
Loss: 0.10191841311042969
val: 0.561824 test: 0.609887
Epoch: 72
Loss: 0.11361937400858171
val: 0.591500 test: 0.622308
Epoch: 73
Loss: 0.10490253882491668
val: 0.610691 test: 0.619335
Epoch: 74
Loss: 0.09490131772865523
val: 0.618027 test: 0.596461
Epoch: 75
Loss: 0.1156856352257912
val: 0.643737 test: 0.597033
Epoch: 76
Loss: 0.09388995478264883
val: 0.652374 test: 0.595091
Epoch: 77
Loss: 0.10811315394162796
val: 0.610649 test: 0.574688
Epoch: 78
Loss: 0.10319929823191863
val: 0.581131 test: 0.597745
Epoch: 79
Loss: 0.10651510424792467
val: 0.609778 test: 0.595701
Epoch: 80
Loss: 0.11681465909044744
val: 0.635403 test: 0.608692
Epoch: 81
Loss: 0.0880049808133542
val: 0.606734 test: 0.588925
Epoch: 82
Loss: 0.10381243157490846
val: 0.573224 test: 0.583282
Epoch: 83
Loss: 0.08876446211393318
val: 0.564409 test: 0.588673
Epoch: 84
Loss: 0.09330504039850987
val: 0.560004 test: 0.597551
Epoch: 85
Loss: 0.08981156514877242
val: 0.557728 test: 0.602097
Epoch: 86
Loss: 0.10622473920822918
val: 0.566855 test: 0.599331
Epoch: 87
Loss: 0.07594522236525558
val: 0.644541 test: 0.594666
Epoch: 88
Loss: 0.07898343053583962
val: 0.656333 test: 0.596486
Epoch: 89
Loss: 0.10304699343774959
val: 0.663899 test: 0.598353
Epoch: 90
Loss: 0.09982285622781926
val: 0.647656 test: 0.589929
Epoch: 91
Loss: 0.08477227768232437
val: 0.637571 test: 0.601520
Epoch: 92
Loss: 0.08859610125810924
val: 0.656620 test: 0.615835
Epoch: 93
Loss: 0.08263467087315893
val: 0.627936 test: 0.593549
Epoch: 94
Loss: 0.09459970997364853
val: 0.599300 test: 0.580668
Epoch: 95
Loss: 0.07965596209709733
val: 0.632426 test: 0.583793
Epoch: 96
Loss: 0.06998769364210963
val: 0.640449 test: 0.580934
Epoch: 97
Loss: 0.08737638635008234
val: 0.634919 test: 0.576853
Epoch: 98
Loss: 0.09454356716638188
val: 0.638173 test: 0.580184
Epoch: 99
Loss: 0.07616965204530123
val: 0.643492 test: 0.600817
Epoch: 100
Loss: 0.0683504562626202
val: 0.633123 test: 0.627766
[10]:
# Validation and test scores at the epoch with the best validation score.
best_val_score = val_roc_list[best_val_idx]
best_test_score = test_roc_list[best_val_idx]
print('best val: {:.6f}\ttest: {:.6f}'.format(best_val_score, best_test_score))
best val: 0.663913 test: 0.664687
6. Gather the results from 5 independent runs.
Here we omit the outputs of the other runs and provide the integrated results for three classification tasks (HIV, MUV and PCBA).
[19]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit
# Results of 5 independent runs (one row per run) for each task. The 9 columns
# follow the training-set ratios used in the plotting cell:
# 1%, 5%, 10%, 20%, 30%, 40%, 60%, 80%, 100%.
# Metric: ROC-AUC for HIV and MUV, AP for PCBA (matching the plot's y-labels).
# NOTE(review): presumably the test score at the best-validation epoch of each
# run (cf. the 'best val ... test' printout above). Confirm.
hiv = np.array([
[0.5896, 0.6812, 0.7460, 0.7774, 0.7768, 0.7998, 0.8310, 0.8411, 0.8531],
[0.6647, 0.7389, 0.7801, 0.7874, 0.8040, 0.8039, 0.8279, 0.8237, 0.8710],
[0.6258, 0.7037, 0.7090, 0.8054, 0.7857, 0.7734, 0.8187, 0.8396, 0.8412],
[0.6458, 0.6869, 0.7464, 0.7768, 0.7636, 0.7932, 0.8067, 0.8415, 0.8501],
[0.6432, 0.6805, 0.7669, 0.7776, 0.8186, 0.7801, 0.7892, 0.8428, 0.8326]
])
muv = np.array([
[0.4381, 0.5454, 0.5818, 0.6727, 0.7621, 0.6712, 0.7054, 0.7617, 0.8068],
[0.5378, 0.6388, 0.6178, 0.7187, 0.7038, 0.7215, 0.7386, 0.7834, 0.8024],
[0.4707, 0.5556, 0.5380, 0.6610, 0.6668, 0.7252, 0.7222, 0.7891, 0.8183],
[0.4498, 0.5724, 0.5653, 0.6855, 0.6907, 0.6618, 0.7230, 0.7255, 0.7780],
[0.4755, 0.6006, 0.6749, 0.7079, 0.7379, 0.7317, 0.7554, 0.8083, 0.7814]
])
pcba = np.array([
[0.0524, 0.0958, 0.1296, 0.1750, 0.1961, 0.2161, 0.2468, 0.2673, 0.2829],
[0.0502, 0.0942, 0.1339, 0.1823, 0.2087, 0.2303, 0.2534, 0.2733, 0.2809],
[0.0526, 0.0961, 0.1339, 0.1691, 0.1912, 0.2161, 0.2443, 0.2625, 0.2735],
[0.0512, 0.0952, 0.1302, 0.1836, 0.2048, 0.2203, 0.2551, 0.2662, 0.2791],
[0.0542, 0.0995, 0.1327, 0.1775, 0.2000, 0.2262, 0.2448, 0.2713, 0.2794]
])
# Plot order: one column of panels per task.
data_list = [hiv, muv, pcba]
7. Fit the curve and plot the neural scaling law.
[24]:
def power_law(x, a, b):
    """Power-law model ``a * x**b`` fitted to score vs. training-set ratio."""
    return a * np.power(x, b)


# Training-set ratios, one per column of each result array.
x = np.array([0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.6, 0.8, 1])
tick_labels = ['1%', '', '10%', '', '30%', '', '60%', '', '100%']
# Panel metadata, in the same order as data_list.
task_names = ['HIV', 'MUV', 'PCBA']
metric_names = ['ROC-AUC', 'ROC-AUC', 'AP']

fig, axs = plt.subplots(2, 3, figsize=(15, 7))
for i, data in enumerate(data_list):
    mean_data = np.mean(data, axis=0)
    # The shaded band spans mean +/- half a standard deviation across runs.
    std_data = np.std(data, axis=0) / 2
    max_data = mean_data + std_data
    min_data = mean_data - std_data
    params, _ = curve_fit(power_law, x, mean_data)
    fitted = power_law(x, params[0], params[1])

    # Row 0: linear axes. Row 1: the same data on log-log axes.
    for row, ax in enumerate((axs[0][i], axs[1][i])):
        if row == 1:
            # Switch to log scale before placing ticks. On a log axis the ticks
            # must sit at the ratios themselves, not at np.log(ratio), which is <= 0.
            ax.set_xscale('log')
            ax.set_yscale('log')
        ax.fill_between(x, max_data, min_data, alpha=0.3)
        ax.scatter(x, mean_data, marker='o', color='royalblue', s=30)
        ax.plot(x, fitted, linewidth=1.5, c='blue')
        ax.set_xlabel('Select Ratio')
        ax.set_ylabel(metric_names[i])
        ax.set_xticks(x)
        ax.set_xticklabels(tick_labels)
        suffix = ' (log-log)' if row == 1 else ''
        ax.set_title('({}) {}{}'.format('abcdef'[row * 3 + i], task_names[i], suffix))

fig.tight_layout()
plt.show()