PyTorch: Prop3D with Graphs (ProteinMPNN)
Here we show how to use Prop3D in a PyTorch model to predict the electrostatic potential using ProteinMPNN.
Install prerequisites if needed
Uncomment the line below to clone the ProteinMPNN repository, which provides `model_utils`.
[ ]:
#!git clone https://github.com/dauparas/ProteinMPNN.git
Define imports
[ ]:
import os
import sys

import torch
from torch import nn
from tqdm import tqdm

# model_utils comes from the cloned ProteinMPNN repository, so its training
# directory must be on sys.path before importing it.
sys.path.append("ProteinMPNN/training")
from Prop3D.ml.datasets.DistributedProteinMPNNDataset import DistributedProteinMPNNDataset
from model_utils import ProteinMPNN, featurize, get_std_opt, loss_nll
# Seed PyTorch's global random number generator.
torch.manual_seed(0)
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
Define the MPNN model
Instead of predicting the 21 amino-acid characters, predict only 3 classes: is_neutral (0), is_electronegative (1), and is_electropositive (2).
Define parameters
[ ]:
# Connection settings for the Prop3D HSDS service, presumably read by the HSDS
# client used inside the dataset. Username and password are the literal
# string "None", not Python's None.
os.environ.update(
    HS_ENDPOINT="http://prop3d-hsds.pods.uvarc.io",
    HS_USERNAME="None",
    HS_PASSWORD="None",
)

# Location of the Prop3D HDF5 file on the server.
cath_file = "/CATH/Prop3D-20.h5"
# CATH superfamily 2.60.40.10, written with "/" separators instead of ".".
cath_superfamily = "2/60/40/10"
# Per-residue Prop3D feature(s) to predict. Other options include charge,
# hydrophobicity, accessibility, the 3 secondary-structure types, etc.
predict_features = ["electrostatic_potential"]
[ ]:
def collate(x):
    """Return the list of dataset items exactly as the DataLoader passed it.

    Used as ``collate_fn`` so PyTorch's default collation (which tries to
    stack items into tensors) is skipped. The raw list of per-protein items
    is featurized later in the notebook.
    """
    return x
[ ]:
def _build_split(validation):
    """Return ``(dataset, loader)`` for the training or validation split.

    Both splits use the same superfamily, features and S100 clustering. Only
    the training loader shuffles, and only the validation dataset is created
    with ``validation=True``.
    """
    dataset_kwargs = {"predict_features": predict_features, "cluster_level": "S100"}
    if validation:
        dataset_kwargs["validation"] = True
    dataset = DistributedProteinMPNNDataset(cath_file, cath_superfamily, **dataset_kwargs)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=16,
        shuffle=not validation,
        num_workers=64,
        collate_fn=collate,  # keep batches as plain lists of items
    )
    return dataset, loader


dataset_train, training_loader = _build_split(validation=False)
dataset_val, val_loader = _build_split(validation=True)
[ ]:
# Indicator tuple (is_zero, is_negative, is_positive) -> class index.
charge_to_idx = {(1, 0, 0): 0, (0, 1, 0): 1, (0, 0, 1): 2}


def process_batch(batch):
    """Featurize a batch and replace its sequence with 3-class charge labels.

    Runs ProteinMPNN's ``featurize`` on ``batch``, then overwrites the
    sequence tensor ``S`` with one label per residue, based on the sign of
    ``prot["prop3d_features"]``:
    0 for a zero value, 1 for negative, 2 for positive (see ``charge_to_idx``).
    Values that are none of these (e.g. NaN) and positions with no feature
    value get label 0.

    Returns the same 8-tuple as ``featurize``, with ``S`` replaced.
    """
    X, S, mask, lengths, chain_M, residue_idx, mask_self, chain_encoding_all = featurize(batch, device)
    neutral = charge_to_idx[(1, 0, 0)]
    negative = charge_to_idx[(0, 1, 0)]
    positive = charge_to_idx[(0, 0, 1)]
    # Build all labels on the CPU and move them to S's device in one copy,
    # rather than doing one scalar tensor write per residue.
    # Starting from a fresh all-neutral tensor also ensures no amino-acid
    # index from featurize survives at a position the features do not cover.
    # Such an index would be a wrong label, or out of range for the 3-class model.
    labels = torch.full(S.shape, neutral, dtype=S.dtype)
    for i, prot in enumerate(batch):
        values = torch.as_tensor(prot["prop3d_features"], dtype=torch.float64)
        row = labels[i, : len(values)]  # view into labels; writes go through
        # NaN compares False both ways, so it stays neutral.
        row[values < 0] = negative
        row[values > 0] = positive
    S = labels.to(S.device)
    return X, S, mask, lengths, chain_M, residue_idx, mask_self, chain_encoding_all
[ ]:
def loss_smoothed(S, log_probs, mask, weight=0.1, normalizer=2000.0):
    """Label-smoothed negative log-likelihood.

    Args:
        S: integer class targets; every value must be < ``log_probs.size(-1)``.
        log_probs: log-probabilities with the class dimension last.
        mask: per-position weights multiplied into the loss before summing.
        weight: label-smoothing amount. ``weight / C`` is added to every class
            of the one-hot target (C = number of classes), then the target is
            renormalized to sum to 1.
        normalizer: constant the masked loss sum is divided by. The default
            2000.0 reproduces the original fixed constant, rather than the
            number of unmasked positions.

    Returns:
        ``(loss, loss_av)``: the per-position smoothed loss (one value per
        target in ``S``) and the masked sum divided by ``normalizer``.
    """
    # Take the class count from the model output instead of hard-coding 3,
    # so the loss also works for other alphabet sizes.
    num_classes = log_probs.size(-1)
    S_onehot = torch.nn.functional.one_hot(S, num_classes).float()
    # Label smoothing
    S_onehot = S_onehot + weight / float(S_onehot.size(-1))
    S_onehot = S_onehot / S_onehot.sum(-1, keepdim=True)
    loss = -(S_onehot * log_probs).sum(-1)
    loss_av = torch.sum(loss * mask) / normalizer
    return loss, loss_av
[ ]:
# ProteinMPNN with a 3-symbol alphabet (the charge classes) instead of the
# 21 amino-acid characters, moved to the selected device.
model = ProteinMPNN(num_letters=3, vocab=3).to(device)
# Optimizer factory from ProteinMPNN's model_utils. The 128 is presumably the
# model width used by its learning-rate schedule and 0 the starting step;
# confirm against model_utils.get_std_opt.
optimizer = get_std_opt(model.parameters(), 128, 0)
[ ]:
# 200 epochs; each epoch is one training pass followed by one validation pass.
for epoch in range(200):
    for loader, is_train in [(training_loader, True), (val_loader, False)]:
        # model.train(False) is equivalent to model.eval().
        model.train(is_train)
        name = "TRAIN" if is_train else "VALIDATION"
        pbar = tqdm(loader)
        for data in pbar:
            X, S, mask, lengths, chain_M, residue_idx, mask_self, chain_encoding_all = process_batch(data)
            if is_train:
                # Zero your gradients for every batch!
                optimizer.zero_grad()
            # Positions that count toward the loss: mask AND chain_M.
            mask_for_loss = mask * chain_M
            # Track gradients only while training; the validation pass does not
            # need (or keep) an autograd graph.
            with torch.set_grad_enabled(is_train):
                # Make predictions for this batch
                log_probs = model(X, S, mask, chain_M, residue_idx, chain_encoding_all)
                _, loss_av_smoothed = loss_smoothed(S, log_probs, mask_for_loss)
            if is_train:
                loss_av_smoothed.backward()
                # Adjust learning weights
                optimizer.step()
            # Report the masked average NLL (second return value of loss_nll,
            # presumably sum(loss * mask) / sum(mask)). The original printed
            # loss.mean() of the per-position loss instead, which also
            # averages padded/unmasked positions.
            with torch.no_grad():
                _, loss_av, _ = loss_nll(S, log_probs, mask_for_loss)
            pbar.set_description(f"Epoch {epoch} {name} Loss {float(loss_av):.4f}")