Example of 2D Graph Modality

Here we provide example code for exploring the neural scaling law of the 2D graph modality with a GIN backbone. You can easily compare this example with the other modalities (fingerprint, SMILES string, and 3D graph), since they all use a similar code framework.

[1]:
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score

import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import global_mean_pool
from torch_geometric.loader import DataLoader
from splitter import random_split


from datasets.molnet import MoleculeDataset
from model.gnn import GNN
from model.mlp import MLP

1. Define the helper functions used throughout this example.

[2]:
def seed_all(seed):
    """Seed every random number generator used in this example.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU plus every
    CUDA device) so that a single call makes the run reproducible.

    Args:
        seed: Integer seed. A falsy value (``None``, ``0``) falls back to 0.
    """
    if not seed:
        seed = 0
    print("[ Using Seed : ", seed, " ]")
    # ``random`` drives the training-subset shuffle later on, so seed it too.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every visible GPU (and is safe without CUDA),
    # which makes a separate torch.cuda.manual_seed call redundant.
    torch.cuda.manual_seed_all(seed)
    return

def get_num_task(dataset):
    """Return the output dimension (number of prediction tasks) for a dataset.

    Args:
        dataset: MoleculeNet dataset name, e.g. ``'hiv'`` or ``'tox21'``.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    # (accepted names, number of tasks) pairs; equality-based membership test.
    task_table = (
        (('tox21',), 12),
        (('hiv', 'bace', 'bbbp'), 1),
        (('muv',), 17),
        (('toxcast',), 617),
        (('sider',), 27),
        (('clintox',), 2),
        (('pcba',), 92),
    )
    for names, n_tasks in task_table:
        if dataset in names:
            return n_tasks
    raise ValueError('Invalid dataset name.')

# Global seed for the whole run. seed_all already seeds all CUDA generators,
# so no additional torch.cuda.manual_seed_all call is needed here.
seed = 0
seed_all(seed)

# Prefer the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

[ Using Seed :  0  ]

2. Load the dataset and partition it with a random split. You can easily replace it with other split methods (e.g., scaffold or imbalanced splits).

[3]:
import os

dataset_name = 'hiv'
split = 'random'
num_tasks = get_num_task(dataset_name)
# Set your dataset directory (os.path.join makes the trailing slash optional).
dataset_folder = '/home/YourDir/datasets/molecule_net/'
dataset_root = os.path.join(dataset_folder, dataset_name)
dataset = MoleculeDataset(dataset_root, dataset=dataset_name)

print(dataset)
eval_metric = roc_auc_score

if split == 'random':
    # processed/smiles.csv has no header; column 0 holds the SMILES strings.
    smiles_list = pd.read_csv(os.path.join(dataset_root, 'processed', 'smiles.csv'),
                              header=None)[0].tolist()
    # 80/10/10 split with its own fixed seed (42), independent of the training seed.
    # NOTE(review): null_value=0 presumably marks missing labels, matching the
    # ``y ** 2 > 0`` masks used during training/evaluation — confirm in splitter.
    train_dataset, valid_dataset, test_dataset, (train_smiles, valid_smiles, test_smiles), _ = random_split(
        dataset, null_value=0, frac_train=0.8, frac_valid=0.1,
        frac_test=0.1, seed=42, smiles_list=smiles_list)
    print('randomly split')

else:
    raise ValueError('Invalid split option.')
print(train_dataset[0])
print('Training data length: {}'.format(len(train_smiles)))
Dataset: hiv
Data: Data(x=[1049163, 2], edge_index=[2, 2259376], edge_attr=[2259376, 2], id=[41127], fingerprint=[41127, 1024], y=[41127])
MoleculeDataset(41127)
randomly split
Data(y=[1], x=[39, 2], edge_index=[2, 84], fingerprint=[1, 1024], edge_attr=[84, 2], id=[1])
Training data length: 32901

3. Randomly select a partial training set (1%, 5%, 10%, …, 100%) for model training.

[4]:
# Fraction of the training split used for fine-tuning (1%, 5%, ..., 100% in the study).
finetune_ratio = 0.01
batch_size = 256
num_mols = len(train_dataset)
# Number of molecules actually kept; at least one so a tiny ratio never
# yields an empty training loader.
finetune_num = max(1, int(finetune_ratio * num_mols))

# Shuffle indices with a fixed seed so every run draws the same subset.
random.seed(seed)
all_idx = list(range(num_mols))
random.shuffle(all_idx)
ids = all_idx[:finetune_num]
train_dataset = train_dataset[ids]

train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          shuffle=True, num_workers=8)
val_loader = DataLoader(valid_dataset, batch_size=batch_size,
                        shuffle=False, num_workers=8)
test_loader = DataLoader(test_dataset, batch_size=batch_size,
                         shuffle=False, num_workers=8)

4. Set up the backbone model and optimizer.

[5]:
# Backbone hyper-parameters.
num_layer = 5
emb_dim = 300
dropout_ratio = 0.5
lr = 1e-3

# GNN encoder; its node outputs are mean-pooled per graph in the train/eval loops.
model = GNN(num_layer=num_layer, emb_dim=emb_dim, drop_ratio=dropout_ratio).to(device)
# Single-layer prediction head: pooled graph embedding -> one logit per task.
output_layer = MLP(in_channels=emb_dim, hidden_channels=emb_dim,
                   out_channels=num_tasks, num_layers=1, dropout=0).to(device)

# Head first, then backbone; both share the same learning rate.
model_param_group = [
    {'params': output_layer.parameters(), 'lr': lr},
    {'params': model.parameters(), 'lr': lr},
]

print(model)
optimizer = optim.Adam(model_param_group, lr=lr, weight_decay=0)
GNN(
  (x_embedding1): Embedding(120, 300)
  (x_embedding2): Embedding(3, 300)
  (gnns): ModuleList(
    (0): GINConv()
    (1): GINConv()
    (2): GINConv()
    (3): GINConv()
    (4): GINConv()
  )
  (batch_norms): ModuleList(
    (0): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

5. Train and evaluate the model.

[6]:
# Element-wise BCE on logits; missing labels are masked out by the loops.
criterion = nn.BCEWithLogitsLoss(reduction='none')

# Per-epoch metric histories, filled in by the training loop.
train_roc_list = []
val_roc_list = []
test_roc_list = []
train_loss_list = []
val_loss_list = []
test_loss_list = []
roc_lists = []  # per-task test scores, one list per epoch

# Model-selection bookkeeping.
best_val_idx = 0
best_val_roc = -1
# Lowest mean training loss seen so far (updated inside train_general).
optimal_loss = 1e10

def train_general(model, device, loader, optimizer):
    """Run one training epoch and return the mean per-batch loss.

    Uses the module-level ``output_layer`` prediction head and ``criterion``.
    Labels appear to be encoded as -1/+1 with 0 meaning "missing": zeros are
    masked out via ``y ** 2 > 0`` and ``(y + 1) / 2`` maps -1/+1 to BCE
    targets 0/1.

    Side effect: lowers the global ``optimal_loss`` when this epoch's loss is
    the best seen so far.

    Args:
        model: GNN backbone returning per-node embeddings.
        device: Device the batches are moved to.
        loader: PyG DataLoader over the training graphs.
        optimizer: Optimizer over the backbone and head parameters.

    Returns:
        Mean loss over batches that had at least one valid label, or NaN if
        no batch contributed (e.g. an empty loader).
    """
    global optimal_loss
    model.train()
    # The head is a separate module, so its train/eval mode must be set too.
    output_layer.train()
    total_loss = 0.0
    num_batches = 0

    for batch in loader:
        batch = batch.to(device)
        h = global_mean_pool(model(batch), batch.batch)
        pred = output_layer(h)

        y = batch.y.view(pred.shape).to(torch.float64)
        is_valid = y ** 2 > 0
        num_valid = is_valid.sum()
        if num_valid == 0:
            # No labelled entries: the loss would be 0/0 = NaN and poison the
            # weights through backward(), so skip this batch.
            continue
        loss_mat = criterion(pred.double(), (y + 1) / 2)
        loss_mat = torch.where(is_valid, loss_mat, torch.zeros_like(loss_mat))

        optimizer.zero_grad()
        loss = loss_mat.sum() / num_valid
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1

    epoch_loss = total_loss / num_batches if num_batches else float('nan')
    # NaN never compares smaller, so an empty epoch leaves optimal_loss alone.
    if epoch_loss < optimal_loss:
        optimal_loss = epoch_loss

    return epoch_loss


def eval_general(model, device, loader):
    """Evaluate the model on ``loader``.

    Uses the module-level ``output_layer``, ``criterion`` and ``eval_metric``.
    Entries whose label is 0 are treated as missing and excluded from both the
    loss and the per-task scores.

    Args:
        model: GNN backbone returning per-node embeddings.
        device: Device the batches are moved to.
        loader: PyG DataLoader over the evaluation graphs.

    Returns:
        ``(mean_score, mean_loss, roc_list)`` where ``roc_list`` holds
        ``eval_metric`` for every task that has both positive (1) and negative
        (-1) labels, ``mean_score`` is their average (NaN if no task is
        evaluable), and ``mean_loss`` is the masked BCE loss averaged over
        batches with at least one valid label (NaN if there were none).
    """
    model.eval()
    output_layer.eval()
    y_true, y_scores = [], []
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            h = global_mean_pool(model(batch), batch.batch)
            pred = output_layer(h)
            true = batch.y.view(pred.shape)

            # Same masked BCE as in training so val/test losses are meaningful.
            y = true.to(torch.float64)
            is_valid = y ** 2 > 0
            num_valid = is_valid.sum()
            if num_valid > 0:
                loss_mat = criterion(pred.double(), (y + 1) / 2)
                loss_mat = torch.where(is_valid, loss_mat, torch.zeros_like(loss_mat))
                total_loss += (loss_mat.sum() / num_valid).item()
                num_batches += 1

            y_true.append(true)
            y_scores.append(pred)

    y_true = torch.cat(y_true, dim=0).cpu().numpy()
    y_scores = torch.cat(y_scores, dim=0).cpu().numpy()

    roc_list = []
    for i in range(y_true.shape[1]):
        column = y_true[:, i]
        # ROC-AUC is only defined when a task has both positive and negative labels.
        if np.any(column == 1) and np.any(column == -1):
            is_valid = column ** 2 > 0
            roc_list.append(eval_metric((column[is_valid] + 1) / 2, y_scores[is_valid, i]))
        else:
            print('{} is invalid'.format(i))

    if len(roc_list) < y_true.shape[1]:
        print(len(roc_list))
        print('Some target is missing!')
        print('Missing ratio: %f' % (1 - float(len(roc_list)) / y_true.shape[1]))

    mean_score = sum(roc_list) / len(roc_list) if roc_list else float('nan')
    mean_loss = total_loss / num_batches if num_batches else float('nan')
    return mean_score, mean_loss, roc_list
[7]:
epochs = 100

# Indirection so other train/eval routines can be swapped in easily.
train_func = train_general
eval_func = eval_general

for epoch in range(1, epochs + 1):
    loss_acc = train_func(model, device, train_loader, optimizer)
    print('Epoch: {}\nLoss: {}'.format(epoch, loss_acc))

    # Training-set metrics are not evaluated here; record 0 as a placeholder.
    train_roc = 0
    train_loss = 0

    val_roc, val_loss, _ = eval_func(model, device, val_loader)
    test_roc, test_loss, roc_list = eval_func(model, device, test_loader)

    for history, value in (
        (train_roc_list, train_roc),
        (val_roc_list, val_roc),
        (test_roc_list, test_roc),
        (train_loss_list, train_loss),
        (val_loss_list, val_loss),
        (test_loss_list, test_loss),
        (roc_lists, roc_list),
    ):
        history.append(value)
    print('val: {:.6f}\ttest: {:.6f}'.format(val_roc, test_roc))
    print()

    # Remember the 0-based epoch index with the best validation score.
    if val_roc > best_val_roc:
        best_val_roc, best_val_idx = val_roc, epoch - 1
Epoch: 1
Loss: 0.7890448979555018
val: 0.493066   test: 0.467586

Epoch: 2
Loss: 0.7206984094669178
val: 0.521066   test: 0.567369

Epoch: 3
Loss: 0.6808990557626126
val: 0.546758   test: 0.583905

Epoch: 4
Loss: 0.6745207162534314
val: 0.571512   test: 0.604082

Epoch: 5
Loss: 0.6650212782655167
val: 0.555260   test: 0.630019

Epoch: 6
Loss: 0.6609979158821531
val: 0.552788   test: 0.629470

Epoch: 7
Loss: 0.6365597884904656
val: 0.545063   test: 0.609288

Epoch: 8
Loss: 0.6325839394666748
val: 0.558695   test: 0.598746

Epoch: 9
Loss: 0.6194784707871628
val: 0.560854   test: 0.613101

Epoch: 10
Loss: 0.6119444207838427
val: 0.545393   test: 0.587247

Epoch: 11
Loss: 0.6084267317479832
val: 0.546163   test: 0.584681

Epoch: 12
Loss: 0.594379795679181
val: 0.550446   test: 0.580790

Epoch: 13
Loss: 0.5814032406936134
val: 0.558260   test: 0.591325

Epoch: 14
Loss: 0.5821246152011825
val: 0.572061   test: 0.609039

Epoch: 15
Loss: 0.5623563324240084
val: 0.572713   test: 0.608568

Epoch: 16
Loss: 0.5629411729916913
val: 0.561830   test: 0.589468

Epoch: 17
Loss: 0.5418217095759776
val: 0.564055   test: 0.575939

Epoch: 18
Loss: 0.5380657321434608
val: 0.590165   test: 0.533879

Epoch: 19
Loss: 0.5214687218438495
val: 0.612862   test: 0.511216

Epoch: 20
Loss: 0.5150101393634872
val: 0.624103   test: 0.540542

Epoch: 21
Loss: 0.4955445477030514
val: 0.634808   test: 0.578347

Epoch: 22
Loss: 0.48988493107218567
val: 0.632235   test: 0.596660

Epoch: 23
Loss: 0.4843246491776333
val: 0.634006   test: 0.589116

Epoch: 24
Loss: 0.46250356726144015
val: 0.610571   test: 0.546207

Epoch: 25
Loss: 0.44351343675443594
val: 0.585388   test: 0.517635

Epoch: 26
Loss: 0.4311040234292012
val: 0.584375   test: 0.506485

Epoch: 27
Loss: 0.42884370231803304
val: 0.583910   test: 0.502525

Epoch: 28
Loss: 0.41645678149041465
val: 0.593627   test: 0.503896

Epoch: 29
Loss: 0.3978502586092939
val: 0.617391   test: 0.506926

Epoch: 30
Loss: 0.3753017187324836
val: 0.635705   test: 0.532077

Epoch: 31
Loss: 0.37998483145674034
val: 0.628083   test: 0.533108

Epoch: 32
Loss: 0.3576763418170913
val: 0.589657   test: 0.512151

Epoch: 33
Loss: 0.33572181544907065
val: 0.559906   test: 0.526929

Epoch: 34
Loss: 0.3245714498859569
val: 0.563185   test: 0.564498

Epoch: 35
Loss: 0.3115902600749161
val: 0.571337   test: 0.583199

Epoch: 36
Loss: 0.30301906367055675
val: 0.564221   test: 0.549889

Epoch: 37
Loss: 0.29085193507140294
val: 0.552433   test: 0.515697

Epoch: 38
Loss: 0.280614545207421
val: 0.549898   test: 0.509258

Epoch: 39
Loss: 0.2638803713267179
val: 0.564261   test: 0.515690

Epoch: 40
Loss: 0.2784722773760675
val: 0.606568   test: 0.533551

Epoch: 41
Loss: 0.23600026165419447
val: 0.620526   test: 0.560151

Epoch: 42
Loss: 0.22487450013132906
val: 0.619791   test: 0.584050

Epoch: 43
Loss: 0.23739796695298396
val: 0.629150   test: 0.596498

Epoch: 44
Loss: 0.20377325739748045
val: 0.575724   test: 0.553914

Epoch: 45
Loss: 0.2042612950155936
val: 0.563912   test: 0.528640

Epoch: 46
Loss: 0.18829806694788767
val: 0.585577   test: 0.558088

Epoch: 47
Loss: 0.18532209452555687
val: 0.584719   test: 0.581878

Epoch: 48
Loss: 0.19661826269119664
val: 0.577035   test: 0.590247

Epoch: 49
Loss: 0.1777977850554482
val: 0.591134   test: 0.614577

Epoch: 50
Loss: 0.1688708345836757
val: 0.621861   test: 0.651868

Epoch: 51
Loss: 0.15549886980732563
val: 0.648588   test: 0.664025

Epoch: 52
Loss: 0.15489401141893733
val: 0.663913   test: 0.664687

Epoch: 53
Loss: 0.14611072828856853
val: 0.662711   test: 0.655157

Epoch: 54
Loss: 0.1414666344323933
val: 0.647173   test: 0.640932

Epoch: 55
Loss: 0.14527067568702584
val: 0.623229   test: 0.626670

Epoch: 56
Loss: 0.14593197044572598
val: 0.606861   test: 0.604167

Epoch: 57
Loss: 0.13939075120189384
val: 0.610515   test: 0.586252

Epoch: 58
Loss: 0.1251759732087336
val: 0.611797   test: 0.538693

Epoch: 59
Loss: 0.1343301963358481
val: 0.616589   test: 0.522969

Epoch: 60
Loss: 0.11001557274238116
val: 0.619285   test: 0.524458

Epoch: 61
Loss: 0.15277526986192563
val: 0.608661   test: 0.545176

Epoch: 62
Loss: 0.11538843484918088
val: 0.609680   test: 0.588619

Epoch: 63
Loss: 0.0985766438425654
val: 0.620455   test: 0.609299

Epoch: 64
Loss: 0.10735864472986267
val: 0.642322   test: 0.621622

Epoch: 65
Loss: 0.10676408858572192
val: 0.639177   test: 0.624693

Epoch: 66
Loss: 0.13325658329221335
val: 0.547236   test: 0.545921

Epoch: 67
Loss: 0.09812393359235637
val: 0.435594   test: 0.449668

Epoch: 68
Loss: 0.11815496698191
val: 0.440811   test: 0.456065

Epoch: 69
Loss: 0.13404194416895196
val: 0.489754   test: 0.533839

Epoch: 70
Loss: 0.09608796220740463
val: 0.537089   test: 0.593275

Epoch: 71
Loss: 0.10191841311042969
val: 0.561824   test: 0.609887

Epoch: 72
Loss: 0.11361937400858171
val: 0.591500   test: 0.622308

Epoch: 73
Loss: 0.10490253882491668
val: 0.610691   test: 0.619335

Epoch: 74
Loss: 0.09490131772865523
val: 0.618027   test: 0.596461

Epoch: 75
Loss: 0.1156856352257912
val: 0.643737   test: 0.597033

Epoch: 76
Loss: 0.09388995478264883
val: 0.652374   test: 0.595091

Epoch: 77
Loss: 0.10811315394162796
val: 0.610649   test: 0.574688

Epoch: 78
Loss: 0.10319929823191863
val: 0.581131   test: 0.597745

Epoch: 79
Loss: 0.10651510424792467
val: 0.609778   test: 0.595701

Epoch: 80
Loss: 0.11681465909044744
val: 0.635403   test: 0.608692

Epoch: 81
Loss: 0.0880049808133542
val: 0.606734   test: 0.588925

Epoch: 82
Loss: 0.10381243157490846
val: 0.573224   test: 0.583282

Epoch: 83
Loss: 0.08876446211393318
val: 0.564409   test: 0.588673

Epoch: 84
Loss: 0.09330504039850987
val: 0.560004   test: 0.597551

Epoch: 85
Loss: 0.08981156514877242
val: 0.557728   test: 0.602097

Epoch: 86
Loss: 0.10622473920822918
val: 0.566855   test: 0.599331

Epoch: 87
Loss: 0.07594522236525558
val: 0.644541   test: 0.594666

Epoch: 88
Loss: 0.07898343053583962
val: 0.656333   test: 0.596486

Epoch: 89
Loss: 0.10304699343774959
val: 0.663899   test: 0.598353

Epoch: 90
Loss: 0.09982285622781926
val: 0.647656   test: 0.589929

Epoch: 91
Loss: 0.08477227768232437
val: 0.637571   test: 0.601520

Epoch: 92
Loss: 0.08859610125810924
val: 0.656620   test: 0.615835

Epoch: 93
Loss: 0.08263467087315893
val: 0.627936   test: 0.593549

Epoch: 94
Loss: 0.09459970997364853
val: 0.599300   test: 0.580668

Epoch: 95
Loss: 0.07965596209709733
val: 0.632426   test: 0.583793

Epoch: 96
Loss: 0.06998769364210963
val: 0.640449   test: 0.580934

Epoch: 97
Loss: 0.08737638635008234
val: 0.634919   test: 0.576853

Epoch: 98
Loss: 0.09454356716638188
val: 0.638173   test: 0.580184

Epoch: 99
Loss: 0.07616965204530123
val: 0.643492   test: 0.600817

Epoch: 100
Loss: 0.0683504562626202
val: 0.633123   test: 0.627766

[10]:
# Report the best validation score and the test score from that same epoch.
best_val_score = val_roc_list[best_val_idx]
test_at_best_val = test_roc_list[best_val_idx]
print('best val: {:.6f}\ttest: {:.6f}'.format(best_val_score, test_at_best_val))
best val: 0.663913      test: 0.664687

6. Gather the results from 5 independent runs.

Here we omit the logs of the other runs and provide the aggregated results for three classification tasks (HIV, MUV, and PCBA).

[19]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit

# Aggregated results of 5 independent runs (one row per run); presumably the
# test score at the best-validation epoch of each run. Columns follow the
# select ratios used as ``x`` in the plotting cell:
# 1%, 5%, 10%, 20%, 30%, 40%, 60%, 80%, 100%.
# HIV: plotted as ROC-AUC.
hiv = np.array([
    [0.5896, 0.6812, 0.7460, 0.7774, 0.7768, 0.7998, 0.8310, 0.8411, 0.8531],
    [0.6647, 0.7389, 0.7801, 0.7874, 0.8040, 0.8039, 0.8279, 0.8237, 0.8710],
    [0.6258, 0.7037, 0.7090, 0.8054, 0.7857, 0.7734, 0.8187, 0.8396, 0.8412],
    [0.6458, 0.6869, 0.7464, 0.7768, 0.7636, 0.7932, 0.8067, 0.8415, 0.8501],
    [0.6432, 0.6805, 0.7669, 0.7776, 0.8186, 0.7801, 0.7892, 0.8428, 0.8326]
])
# MUV: plotted as ROC-AUC.
muv = np.array([
    [0.4381, 0.5454, 0.5818, 0.6727, 0.7621, 0.6712, 0.7054, 0.7617, 0.8068],
    [0.5378, 0.6388, 0.6178, 0.7187, 0.7038, 0.7215, 0.7386, 0.7834, 0.8024],
    [0.4707, 0.5556, 0.5380, 0.6610, 0.6668, 0.7252, 0.7222, 0.7891, 0.8183],
    [0.4498, 0.5724, 0.5653, 0.6855, 0.6907, 0.6618, 0.7230, 0.7255, 0.7780],
    [0.4755, 0.6006, 0.6749, 0.7079, 0.7379, 0.7317, 0.7554, 0.8083, 0.7814]
])
# PCBA: plotted as AP (average precision), not ROC-AUC.
pcba = np.array([
    [0.0524, 0.0958, 0.1296, 0.1750, 0.1961, 0.2161, 0.2468, 0.2673, 0.2829],
    [0.0502, 0.0942, 0.1339, 0.1823, 0.2087, 0.2303, 0.2534, 0.2733, 0.2809],
    [0.0526, 0.0961, 0.1339, 0.1691, 0.1912, 0.2161, 0.2443, 0.2625, 0.2735],
    [0.0512, 0.0952, 0.1302, 0.1836, 0.2048, 0.2203, 0.2551, 0.2662, 0.2791],
    [0.0542, 0.0995, 0.1327, 0.1775, 0.2000, 0.2262, 0.2448, 0.2713, 0.2794]
])
# Order matters: the plotting cell maps index 0/1/2 to the HIV/MUV/PCBA panels.
data_list = [hiv, muv, pcba]

7. Fit the curve and plot the neural scaling law.

[24]:
def power_law(x, a, b):
    """Power-law model ``a * x ** b`` for performance vs. training-set ratio."""
    return a * np.power(x, b)


# Training-set select ratios; must match the columns of each array in data_list.
x = [0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.6, 0.8, 1]
tick_labels = ['1%', '', '10%', '', '30%', '', '60%', '', '100%']
# One entry per element of data_list, in the same order.
panel_names = ['HIV', 'MUV', 'PCBA']
metric_names = ['ROC-AUC', 'ROC-AUC', 'AP']
# Dense grid so the fitted curve is drawn smoothly, not as 9 line segments.
x_fit = np.geomspace(min(x), max(x), 200)

n_cols = len(data_list)
fig, axs = plt.subplots(2, n_cols, figsize=(5 * n_cols, 7), squeeze=False)

for i, data in enumerate(data_list):
    # Statistics across runs (rows) for each select ratio (columns).
    mean_data = np.mean(data, axis=0)
    half_std = np.std(data, axis=0) / 2  # shaded band is +/- half a std
    max_data = mean_data + half_std
    min_data = mean_data - half_std
    # Fit the power law to the mean curve.
    params, _ = curve_fit(power_law, x, mean_data)
    fitted = power_law(x_fit, *params)

    # Row 0: linear axes; row 1: log-log axes. Both show the same data.
    for row in range(2):
        ax = axs[row][i]
        ax.fill_between(x, max_data, min_data, alpha=0.3)
        ax.scatter(x, mean_data, marker='o', color='royalblue', s=30)
        ax.plot(x_fit, fitted, linewidth=1.5, c='blue')
        ax.set_xlabel('Select Ratio')
        ax.set_ylabel(metric_names[i])

    ax_lin = axs[0][i]
    ax_lin.set_xticks(x)
    ax_lin.set_xticklabels(tick_labels)
    ax_lin.set_title('({}) {}'.format(chr(ord('a') + i), panel_names[i]))

    ax_log = axs[1][i]
    # Set the scales BEFORE the ticks: set_xscale installs a fresh tick
    # locator/formatter and would discard custom ticks set earlier. Ticks go
    # at the data positions x (not np.log(x)) because the log scale performs
    # the transform itself.
    ax_log.set_xscale('log')
    ax_log.set_yscale('log')
    ax_log.set_xticks(x)
    ax_log.set_xticklabels(tick_labels)
    ax_log.set_xticks([], minor=True)  # drop unlabeled log minor ticks on x
    ax_log.set_title('({}) {} (log-log)'.format(chr(ord('a') + n_cols + i), panel_names[i]))

fig.tight_layout()
plt.show()
_images/example_18_0.png