PyTorch: Prop3D with Language Models (ESM)

Install prerequisites: PyTorch and Hugging Face Transformers

Uncomment the install cells below if these packages are not installed yet. For a GPU-enabled PyTorch installation, follow the instructions at https://pytorch.org/get-started/locally/

[ ]:
import sys
[ ]:
#!{sys.executable} -m pip install --user torch
[ ]:
#!{sys.executable} -m pip install --user tokenizers transformers

Imports

[ ]:
import os
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer, EsmForTokenClassification, DataCollatorForTokenClassification
from Prop3D.ml.datasets.DistributedDomainSequenceDataset import DistributedDomainSequenceDataset

# Seed PyTorch's global RNG so weight initialization and shuffling are repeatable.
torch.manual_seed(0)

# Prefer the CUDA device when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

Define parameters

[ ]:
# Connection settings read from the environment by the HSDS client.
# NOTE(review): the credentials are the literal string "None", presumably meaning
# anonymous access to the public Prop3D server; confirm.
os.environ.update({
    "HS_ENDPOINT": "http://prop3d-hsds.pods.uvarc.io",
    "HS_USERNAME": "None",
    "HS_PASSWORD": "None",
})

# Path of the Prop3D .h5 file (presumably resolved on the HSDS endpoint above).
cath_file = "/CATH/Prop3D-20.h5"
# CATH superfamily ID, written with "/" separators instead of "." (1.10.10.10 -> 1/10/10/10).
cath_superfamily = "1/10/10/10"

# Per-residue features to predict. Other choices include charge, hydrophobicity,
# accessibility, or (as here) the three secondary-structure classes.
predict_features = ["is_sheet", "is_helix", "Unk_SS"]

Set up ESM

[ ]:
# ESM-2 checkpoint on the Hugging Face Hub (the name encodes 6 layers / 8M parameters).
esm_checkpoint = "facebook/esm2_t6_8M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(esm_checkpoint)
# Token-classification head with one output class per entry in predict_features.
model = EsmForTokenClassification.from_pretrained(
    esm_checkpoint,
    num_labels=len(predict_features),
)
[ ]:
data_collator = DataCollatorForTokenClassification(tokenizer)


def collate(x):
    """Tokenize ``(sequence, features)`` pairs and pad them into one training batch.

    Each row of ``features`` is reduced with ``argmax`` to one class index per residue.
    NOTE(review): assumes ``features`` is a per-residue matrix of shape
    ``(len(sequence), len(predict_features))`` (one-hot over the classes); confirm
    against DistributedDomainSequenceDataset.

    The tokenizer wraps the residues in special tokens (for ESM, ``<cls>`` ... ``<eos>``).
    Those positions get ``data_collator.label_pad_token_id``, the same value the
    collator pads labels with and which the token-classification loss ignores. This
    keeps every residue label on its own token instead of shifting it by one.

    Tensors are returned on the CPU; the training loop moves them to ``device``.

    Raises:
        ValueError: if a sequence does not tokenize into exactly one non-special
            token per label row.
    """
    ignore_index = data_collator.label_pad_token_id
    encodings = []
    for sequence, residue_features in x:
        encoding = tokenizer(sequence, return_special_tokens_mask=True)
        # The mask is only needed for alignment; don't pass it on to the collator.
        special_mask = encoding.pop("special_tokens_mask")
        residue_labels = np.argmax(residue_features, axis=1).tolist()

        n_residue_tokens = len(special_mask) - sum(special_mask)
        if n_residue_tokens != len(residue_labels):
            raise ValueError(
                f"Sequence tokenized into {n_residue_tokens} residue tokens but has "
                f"{len(residue_labels)} label rows"
            )

        # Walk the tokens in order: special tokens are ignored, others take the next label.
        residue_iter = iter(residue_labels)
        encoding["labels"] = [
            ignore_index if is_special else next(residue_iter)
            for is_special in special_mask
        ]
        encodings.append(encoding)

    return data_collator(encodings)

Set up Prop3D datasets and dataloaders

[ ]:
# Settings shared by the training and validation splits of the same superfamily.
dataset_kwargs = {"predict_features": predict_features, "cluster_level": "S100"}
# Settings shared by both loaders; only shuffling differs between them.
loader_kwargs = {"batch_size": 128, "collate_fn": collate}

dataset_train = DistributedDomainSequenceDataset(
    cath_file, cath_superfamily, **dataset_kwargs)
training_loader = torch.utils.data.DataLoader(
    dataset_train, shuffle=True, **loader_kwargs)

dataset_val = DistributedDomainSequenceDataset(
    cath_file, cath_superfamily, validation=True, **dataset_kwargs)
val_loader = torch.utils.data.DataLoader(
    dataset_val, shuffle=False, **loader_kwargs)

Start training

[ ]:
# SGD with momentum over all model parameters (see the torch.optim package).
optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-3, momentum=0.9)
[ ]:
# Put the model on the same device the batches are moved to below.
# NOTE(review): relies on Module.to() keeping the same Parameter objects (the PyTorch
# default), so the optimizer created in the previous cell still tracks them.
model.to(device)

for epoch in range(30):
    for loader, is_train in [(training_loader, True), (val_loader, False)]:
        name = "TRAIN" if is_train else "VALIDATION"
        # Training mode for the training pass, eval mode for validation (dropout etc.).
        model.train(is_train)
        running_loss = 0.0
        n_batches = 0
        pbar = tqdm(loader)
        # Build the autograd graph only while training; validation only needs the loss.
        with torch.set_grad_enabled(is_train):
            for batch in pbar:
                # The batches arrive as CPU tensors; move only the fields the model consumes.
                inputs = {
                    key: batch[key].to(device)
                    for key in ("input_ids", "attention_mask", "labels")
                }
                outputs = model(**inputs)
                loss = outputs.loss

                if is_train:
                    # Fresh gradients for every batch, then backpropagate and update weights.
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                batch_loss = loss.item()
                running_loss += batch_loss
                n_batches += 1
                pbar.set_description(
                    f"Epoch {epoch} {name} Loss {batch_loss:.4f} "
                    f"Mean {running_loss / n_batches:.4f}"
                )