PyTorch: Prop3D with Language Models (ESM)
Install prerequisites: PyTorch and Hugging Face Transformers
Uncomment the install cells below if these packages are not already installed. For a GPU-enabled PyTorch installation, follow the instructions at https://pytorch.org/get-started/locally/
[ ]:
import sys
[ ]:
#!{sys.executable} -m pip install --user torch
[ ]:
#!{sys.executable} -m pip install --user tokenizers transformers
Imports
[ ]:
import os
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer, EsmForTokenClassification, DataCollatorForTokenClassification
from Prop3D.ml.datasets.DistributedDomainSequenceDataset import DistributedDomainSequenceDataset
# Seed PyTorch's global random number generator so runs are repeatable.
torch.manual_seed(0)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
Define parameters
[ ]:
# Connection settings for the remote HSDS endpoint, exported as environment
# variables. The username and password are the literal string "None"
# (presumably the endpoint allows anonymous access -- confirm).
os.environ.update(
    {
        "HS_ENDPOINT": "http://prop3d-hsds.pods.uvarc.io",
        "HS_USERNAME": "None",
        "HS_PASSWORD": "None",
    }
)

# Path of the Prop3D dataset file.
cath_file = "/CATH/Prop3D-20.h5"

# CATH superfamily to load; the levels are separated by "/" rather than ".".
cath_superfamily = "1/10/10/10"

# Per-residue features to predict. Other choices could be charge,
# hydrophobicity, accessibility, the three secondary-structure types, etc.
predict_features = ["is_sheet", "is_helix", "Unk_SS"]
Set up ESM
[ ]:
# The tokenizer and the token-classification model are loaded from the same
# pretrained ESM-2 checkpoint so that their vocabularies agree.
_esm_checkpoint = "facebook/esm2_t6_8M_UR50D"

tokenizer = AutoTokenizer.from_pretrained(_esm_checkpoint)

# One output class per entry in `predict_features`.
model = EsmForTokenClassification.from_pretrained(
    _esm_checkpoint,
    num_labels=len(predict_features),
)
[ ]:
data_collator = DataCollatorForTokenClassification(tokenizer)


def collate(x):
    """Turn a list of ``(sequence, features)`` samples into one padded batch.

    Every sequence is tokenized, the class index of each residue is taken as
    the argmax over the feature columns, and the data collator pads
    ``input_ids``, ``attention_mask`` and ``labels`` to a common length.

    Args:
        x: Samples from the dataset. Each item unpacks as
            ``(sequence, features)`` where ``features`` is reduced with
            ``np.argmax(features, axis=1)`` (presumably one row per residue
            and one column per entry of ``predict_features`` -- confirm
            against the dataset).

    Returns:
        The collated batch, with ``input_ids``, ``attention_mask`` and
        ``labels`` placed on the device the model currently lives on.
    """
    # Use the collator's own padding label for special tokens so that padded
    # and special positions are treated the same way.
    ignore_index = data_collator.label_pad_token_id

    features = []
    for sequence, residue_features in x:
        encoded = tokenizer(sequence, return_special_tokens_mask=True)
        special_mask = encoded.pop("special_tokens_mask")

        # The tokenizer adds special tokens around the sequence, so the
        # per-residue labels cannot simply be right-padded: that would shift
        # every label by one token. Assign labels only to residue tokens.
        residue_labels = iter(np.argmax(residue_features, axis=1).tolist())
        encoded["labels"] = [
            ignore_index if is_special else next(residue_labels, ignore_index)
            for is_special in special_mask
        ]
        features.append(encoded)

    batch = data_collator(features)

    # Tensor.to() returns a new tensor instead of moving it in place, so the
    # result has to be stored back. Targeting the model's current device
    # keeps inputs and parameters together wherever the model was placed.
    target_device = model.device
    for key in ("input_ids", "attention_mask", "labels"):
        batch[key] = batch[key].to(target_device)
    return batch
Set up Prop3D datasets and dataloaders
[ ]:
# Both splits read the same file and superfamily with identical feature and
# cluster settings; only the `validation` flag differs.
_dataset_options = {
    "predict_features": predict_features,
    "cluster_level": "S100",
}

# Both loaders share the batch size and the tokenizing collate function.
_loader_options = {
    "batch_size": 128,
    "collate_fn": collate,
}

dataset_train = DistributedDomainSequenceDataset(
    cath_file, cath_superfamily, **_dataset_options
)
# Training batches are reshuffled on every pass.
training_loader = torch.utils.data.DataLoader(
    dataset_train, shuffle=True, **_loader_options
)

# `validation=True` presumably selects the held-out split -- confirm against
# DistributedDomainSequenceDataset.
dataset_val = DistributedDomainSequenceDataset(
    cath_file, cath_superfamily, validation=True, **_dataset_options
)
# Validation batches keep a fixed order.
val_loader = torch.utils.data.DataLoader(
    dataset_val, shuffle=False, **_loader_options
)
Start training
[ ]:
# Stochastic gradient descent with momentum over all model parameters
# (other optimizers are available in torch.optim).
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=1e-3,
    momentum=0.9,
)
[ ]:
# Train for 30 epochs, running one pass over the training loader followed by
# one pass over the validation loader in every epoch.

# nn.Module.to() moves the parameters in place, so the optimizer keeps
# referring to the same Parameter objects after the move.
model.to(device)

for epoch in range(30):
    for loader, is_train in [(training_loader, True), (val_loader, False)]:
        name = "TRAIN" if is_train else "VALIDATION"

        # Switch dropout and similar layers between training and
        # evaluation behaviour.
        model.train(is_train)

        running_loss = 0.0
        pbar = tqdm(loader)
        for step, batch in enumerate(pbar, start=1):
            # Make sure the inputs sit on the same device as the model; this
            # is a no-op when they are already there.
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            # Gradients are only tracked while training; validation skips
            # building the autograd graph.
            with torch.set_grad_enabled(is_train):
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels)
                loss = outputs.loss

            if is_train:
                # Zero the gradients for every batch, then compute the new
                # gradients and adjust the weights.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # .item() yields a plain float, so no graph is kept alive and
            # the progress bar shows a number instead of a tensor repr.
            batch_loss = loss.item()
            running_loss += batch_loss
            pbar.set_description(
                f"Epoch {epoch} {name} Loss {batch_loss:.4f} "
                f"(mean {running_loss / step:.4f})")