PyTorch: Prop3D with Graphs (ProteinMPNN)

Here we show how to use Prop3D in a PyTorch model to predict the electrostatic potential using ProteinMPNN.

Install prerequisites if needed

Uncomment to install

[ ]:
#!git clone https://github.com/dauparas/ProteinMPNN.git

Define imports

[ ]:

import os
import sys

# model_utils lives in the cloned ProteinMPNN repository; make it importable first.
sys.path.append("ProteinMPNN/training")

import torch
from torch import nn
from tqdm import tqdm

from Prop3D.ml.datasets.DistributedProteinMPNNDataset import DistributedProteinMPNNDataset
from model_utils import ProteinMPNN, featurize, get_std_opt, loss_nll

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

Define MPNN model

Instead of predicting 21 characters, only predict 3: is_electronegative, is_electropositive, is_neutral

Define parameters

[ ]:
# Connection settings exported through the process environment
# (presumably picked up by the HSDS client behind Prop3D -- confirm).
os.environ.update(
    {
        "HS_ENDPOINT": "http://prop3d-hsds.pods.uvarc.io",
        "HS_USERNAME": "None",
        "HS_PASSWORD": "None",
    }
)

# Location of the Prop3D dataset and the CATH superfamily to train on.
# The superfamily code is written with "/" separators instead of ".".
cath_file = "/CATH/Prop3D-20.h5"
cath_superfamily = "2/60/40/10"

# Features to predict. Could be charge, hydrophobicity, accessibility,
# 3 types of secondary structure, etc.
predict_features = ["electrostatic_potential"]
[ ]:
def collate(x):
    """Hand the batch back untouched.

    Used as the ``collate_fn`` of the data loaders so that the samples of a
    batch are passed on as-is instead of being merged by PyTorch's default
    collation.

    Args:
        x: The batch object produced by the loader.

    Returns:
        The very same object ``x``.
    """
    return x
[ ]:
# Keyword arguments shared by the training and validation splits.
_dataset_options = dict(predict_features=predict_features, cluster_level="S100")
_loader_options = dict(batch_size=16, num_workers=64, collate_fn=collate)

# Training split: batches are drawn in shuffled order.
dataset_train = DistributedProteinMPNNDataset(
    cath_file, cath_superfamily, **_dataset_options
)
training_loader = torch.utils.data.DataLoader(
    dataset_train, shuffle=True, **_loader_options
)

# Validation split: same configuration plus validation=True, fixed order.
dataset_val = DistributedProteinMPNNDataset(
    cath_file, cath_superfamily, validation=True, **_dataset_options
)
val_loader = torch.utils.data.DataLoader(
    dataset_val, shuffle=False, **_loader_options
)
[ ]:
# Class index keyed by the flags (value == 0, value < 0, value > 0).
charge_to_idx = {(1,0,0):0, (0,1,0):1, (0,0,1):2}
def process_batch(batch):
    """Featurize a batch and overwrite its sequence labels with charge classes.

    The batch is passed through ``featurize`` and the returned ``S`` tensor is
    then modified in place: for every entry of each protein's
    ``"prop3d_features"`` the label becomes 0 when the value equals zero,
    1 when it is negative and 2 when it is positive. A value matching none of
    these (every comparison false) is labelled 0.

    Assumes each entry of ``prot["prop3d_features"]`` is a scalar -- TODO confirm.

    Args:
        batch: Iterable of protein records, each providing "prop3d_features".

    Returns:
        The eight values produced by ``featurize``, with ``S`` relabelled.
    """
    X, S, mask, lengths, chain_M, residue_idx, mask_self, chain_encoding_all = featurize(batch, device)
    for prot_idx, prot in enumerate(batch):
        for res_idx, value in enumerate(prot["prop3d_features"]):
            sign_flags = (value==0, value<0, value>0)
            # Unknown flag combinations fall back to class 0.
            S[prot_idx, res_idx] = charge_to_idx.get(sign_flags, 0.)
    return X, S, mask, lengths, chain_M, residue_idx, mask_self, chain_encoding_all
[ ]:
def loss_smoothed(S, log_probs, mask, weight=0.1, *, num_classes=3, normalizer=2000.0):
    """Label-smoothed negative log-likelihood.

    Args:
        S: Integer tensor of target class indices in ``[0, num_classes)``.
        log_probs: Log-probabilities with one trailing dimension more than
            ``S``, of size ``num_classes``.
        mask: Tensor with the shape of ``S``; positions weighted 0 do not
            contribute to the averaged loss.
        weight: Smoothing mass spread evenly over all classes before the
            target distribution is renormalised.
        num_classes: Number of classes for the one-hot encoding. Defaults to
            3, the value that used to be hard-coded.
        normalizer: Fixed constant dividing the masked loss sum. Defaults to
            2000.0, the value that used to be hard-coded.

    Returns:
        ``(loss, loss_av)``: the per-position loss with the shape of ``S``,
        and the scalar masked sum divided by ``normalizer``.
    """
    S_onehot = torch.nn.functional.one_hot(S, num_classes).float()

    # Label smoothing: add a uniform share to every class, then renormalise
    # so that each target distribution sums to one again.
    S_onehot = S_onehot + weight / float(num_classes)
    S_onehot = S_onehot / S_onehot.sum(-1, keepdim=True)

    loss = -(S_onehot * log_probs).sum(-1)
    # Divide by a fixed constant rather than by the number of unmasked positions.
    loss_av = torch.sum(loss * mask) / normalizer
    return loss, loss_av
[ ]:
# Build the network with a 3-symbol alphabet (one symbol per charge class)
# and place it on the selected device.
model = ProteinMPNN(num_letters=3, vocab=3).to(device)

# Optimizer from the ProteinMPNN training utilities. The arguments 128 and 0
# are presumably the model width and the starting step -- confirm against
# get_std_opt.
trainable_parameters = model.parameters()
optimizer = get_std_opt(trainable_parameters, 128, 0)
[ ]:
# Train for 200 epochs; every epoch is one pass over the training loader
# followed by one pass over the validation loader.
for epoch in range(200):
    for loader, is_train in [(training_loader, True), (val_loader, False)]:
        name = "TRAIN" if is_train else "VALIDATION"
        # train(True) switches to training mode, train(False) equals eval().
        model.train(is_train)

        pbar = tqdm(loader)
        for data in pbar:
            X, S, mask, lengths, chain_M, residue_idx, mask_self, chain_encoding_all = process_batch(data)

            # Only positions that are unmasked in both tensors count for the loss.
            mask_for_loss = mask*chain_M

            if is_train:
                # Zero your gradients for every batch!
                optimizer.zero_grad()

            # Make predictions for this batch. Autograd is only switched on
            # for training, so validation does not build a graph that is
            # never back-propagated.
            with torch.set_grad_enabled(is_train):
                log_probs = model(X, S, mask, chain_M, residue_idx, chain_encoding_all)
                _, loss_av_smoothed = loss_smoothed(S, log_probs, mask_for_loss)

            if is_train:
                loss_av_smoothed.backward()

                # Adjust learning weights
                optimizer.step()

            # Reporting only: no gradients are needed for the displayed loss.
            # NOTE(review): loss.mean() averages over every position, not just
            # those selected by mask_for_loss; the second value returned by
            # loss_nll is presumably the masked average -- confirm before
            # switching the displayed metric.
            with torch.no_grad():
                loss, _, _ = loss_nll(S, log_probs, mask_for_loss)

            # .item() shows a plain number instead of a tensor repr.
            pbar.set_description(f"Epoch {epoch} {name} Loss {loss.mean().item()}")