Source code for genevector.model

"""GeneVector neural embedding model for gene co-expression learning."""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import numpy
import matplotlib.pyplot as plt
from .embedding import GeneEmbedding
from ._logging import get_logger

# Module-level logger; GeneVector.train reports per-epoch loss and completion through it.
logger = get_logger(__name__)

class GeneVectorModel(nn.Module):
    """Two-table gene embedding model scored by pairwise dot products.

    Holds an input table ``wi`` and an output table ``wj``, each an
    ``nn.Embedding(num_embeddings, embedding_dim)``. The score for a gene pair
    ``(i, j)`` is the dot product of row ``i`` of ``wi`` with row ``j`` of ``wj``.

    Parameters
    ----------
    num_embeddings : int
        Number of genes (vocabulary size).
    embedding_dim : int
        Dimension of each gene embedding vector.
    gain : float, optional
        Scale factor passed to ``nn.init.orthogonal_`` when ``init_ortho`` is True.
    init_ortho : bool, optional
        If True, use orthogonal initialization; otherwise draw weights from
        uniform(-1, 1).
    """

    def __init__(self, num_embeddings, embedding_dim, gain=1., init_ortho=True):
        """Create both embedding tables and initialize their weights."""
        # Set up nn.Module's internal registries before assigning attributes.
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.wi = nn.Embedding(num_embeddings, embedding_dim)
        self.wj = nn.Embedding(num_embeddings, embedding_dim)
        if init_ortho:
            nn.init.orthogonal_(self.wi.weight, gain=gain)
            nn.init.orthogonal_(self.wj.weight, gain=gain)
        else:
            nn.init.uniform_(self.wi.weight, -1.0, 1.0)
            nn.init.uniform_(self.wj.weight, -1.0, 1.0)

    def forward(self, i_indices, j_indices):
        """Compute dot products between gene embedding pairs.

        Parameters
        ----------
        i_indices : torch.LongTensor
            Indices into ``wi`` for the first gene of each pair.
        j_indices : torch.LongTensor
            Indices into ``wj`` for the second gene of each pair.

        Returns
        -------
        torch.Tensor
            One score per pair: the sum over the embedding dimension (dim 1)
            of the element-wise product of the two looked-up vectors.
        """
        w_i = self.wi(i_indices)
        w_j = self.wj(j_indices)
        return (w_i * w_j).sum(dim=1)

    def save_embedding(self, id2word, file_name, layer):
        """Write one embedding table to a word2vec-style ``.vec`` text file.

        The first line is ``"<len(id2word)> <embedding_dim>"``. It is followed by
        one line per ``id2word`` entry: the symbol, then the row values
        separated by spaces.

        Parameters
        ----------
        id2word : dict
            Mapping from gene index (row of the table) to gene symbol.
        file_name : str
            Output file path (overwritten).
        layer : int
            0 writes the input weights (``wi``); any other value writes the
            output weights (``wj``).
        """
        table = self.wi if layer == 0 else self.wj
        # detach() drops autograd tracking; cpu() allows GPU-resident weights.
        embedding = table.weight.detach().cpu().numpy()
        # Explicit UTF-8 keeps non-ASCII gene symbols portable across platforms.
        with open(file_name, "w", encoding="utf-8") as f:
            f.write("%d %d\n" % (len(id2word), self.embedding_dim))
            for wid, w in id2word.items():
                f.write("%s %s\n" % (w, " ".join(map(str, embedding[wid]))))
class GeneVector(object):
    """GeneVector framework for training a gene embedding.

    Parameters
    ----------
    dataset : GeneVector.dataset.GeneVectorDataset
        GeneVector dataset. ``create_inputs_outputs(c=c)`` is called on it
        during construction.
    output_file : str
        Flat file that receives the input-weight (``wi``) embedding. The
        output-weight (``wj``) embedding is written next to it with a "2"
        inserted before the extension (``genes.vec`` -> ``genes2.vec``).
    emb_dimension : int, optional
        Number of hidden units and dimension of the latent representation.
    batch_size : int or None, optional
        Number of gene pairs per batch. If None, defaults to
        ``dataset.num_pairs`` (all pairs), or 1,000,000 when that is falsy.
    gain : float, optional
        Scale factor of orthogonal weight initialization.
    c : float, optional
        Forwarded unchanged to ``dataset.create_inputs_outputs``.
    device : str, optional
        Torch device for the model, e.g. "cpu", "cuda", "cuda:0" or "mps".
        Batches yielded by the dataset are expected to live on the same
        device (assumption: the dataset is configured with a matching device).
    init_ortho : bool, optional
        If True, use orthogonal initialization; otherwise uniform(-1, 1).
    """

    def __init__(self, dataset, output_file, emb_dimension=100, batch_size=None,
                 gain=1, c=100., device="cpu", init_ortho=False):
        """Prepare the dataset and build the model, optimizer and loss.

        Raises
        ------
        ValueError
            If a CUDA device is requested but CUDA is not available.
        """
        self.dataset = dataset
        self.init_ortho = init_ortho
        self.dataset.create_inputs_outputs(c=c)
        self.output_file_name = output_file
        self.emb_size = len(self.dataset.data.gene2id)
        self.emb_dimension = emb_dimension
        if batch_size is not None:
            self.batch_size = batch_size
        elif self.dataset.num_pairs:
            self.batch_size = self.dataset.num_pairs
        else:
            # Integer fallback; the previous float 1e6 is not a valid count.
            self.batch_size = 1_000_000
        self.use_cuda = torch.cuda.is_available()
        self.model = GeneVectorModel(self.emb_size, self.emb_dimension,
                                     gain=gain, init_ortho=init_ortho)
        self.device = device
        if str(device).startswith("cuda") and not self.use_cuda:
            raise ValueError("CUDA requested but no GPU available.")
        # Honour every device string ("cuda:0", "mps", ...), not just "cuda".
        self.model.to(device)
        self.optimizer = optim.Adadelta(self.model.parameters())
        self.loss = nn.MSELoss()
        self.epoch = 0
        self.loss_values = list()       # one total loss per batch
        self.mean_loss_values = []      # one smoothed loss per epoch

    def _regularization(self, alpha, beta):
        """Return ``alpha * off-diagonal penalty + beta * diagonal penalty``.

        The off-diagonal penalty is the sum of squared off-diagonal entries of
        ``wi.weight @ wj.weight.T``. The diagonal penalty is the sum of squared
        differences between that matrix's diagonal and ``dataset._ent``
        (assumed to hold one target value per gene; TODO confirm). A
        zero coefficient skips its term entirely. This avoids building an
        N x N matrix when it would contribute nothing.
        """
        w_i = self.model.wi.weight
        w_j = self.model.wj.weight
        penalty = 0.0
        if alpha:
            cross = torch.matmul(w_i, w_j.t())
            # matmul's backward needs only its inputs, so in-place edit is safe.
            cross.fill_diagonal_(0)
            penalty = penalty + alpha * (cross ** 2).sum()
        if beta:
            # diag(w_i @ w_j.T) computed row-wise in O(N * d) instead of O(N^2 * d).
            diag = (w_i * w_j).sum(dim=1)
            penalty = penalty + beta * ((diag - self.dataset._ent) ** 2).sum()
        return penalty

    @staticmethod
    def _secondary_path(path):
        """Return ``path`` with "2" inserted before its extension."""
        import os
        root, ext = os.path.splitext(path)
        return root + "2" + ext

    def _save_embeddings(self):
        """Write the wi embedding to ``output_file_name`` and wj to its "2" twin."""
        id2gene = self.dataset.data.id2gene
        self.model.save_embedding(id2gene, self.output_file_name, 0)
        self.model.save_embedding(id2gene, self._secondary_path(self.output_file_name), 1)

    def train(self, epochs, threshold=None, update_interval=20, alpha=0.0, beta=0.0):
        """Train for ``epochs`` epochs or until the loss stops changing.

        After each epoch, the mean of the last 10 batch losses is recorded.
        Training stops early when this value changes by less than
        ``threshold`` from the previous epoch. The embeddings are written to
        disk in either case.

        Parameters
        ----------
        epochs : int
            Maximum number of epochs.
        threshold : float or None, optional
            Stopping criterion on the absolute change in smoothed loss;
            None disables early stopping.
        update_interval : int, optional
            Number of epochs between loss log messages.
        alpha : float, optional
            Coefficient of the orthogonality (off-diagonal) penalty.
        beta : float, optional
            Coefficient of the magnitude (diagonal) scaling penalty.
        """
        update_interval = int(update_interval)
        # inf guarantees the first epoch can never satisfy the stopping rule.
        last_loss = float("inf")
        for _ in range(epochs):
            for x_ij, i_idx, j_idx in self.dataset.get_batches(self.batch_size):
                outputs = self.model(i_idx, j_idx)
                loss = self.loss(outputs, x_ij) + self._regularization(alpha, beta)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.loss_values.append(loss.item())
            curr_loss = numpy.mean(self.loss_values[-10:])
            self.mean_loss_values.append(curr_loss)
            if self.epoch % update_interval == 0:
                logger.info(f"Epoch {self.epoch} loss: {round(np.mean(self.loss_values[-30:]), 5)}")
            if threshold is not None and abs(curr_loss - last_loss) < threshold:
                logger.info("Training complete!")
                self._save_embeddings()
                return
            last_loss = curr_loss
            self.epoch += 1
        logger.info("Saving model...")
        self._save_embeddings()

    def save(self, filepath):
        """Save the model state dict to ``filepath`` with ``torch.save``."""
        torch.save(self.model.state_dict(), filepath)

    def load(self, filepath):
        """Load a model state dict from ``filepath`` and switch to eval mode.

        Tensors are mapped onto the device the model currently lives on, so a
        checkpoint saved on GPU can be loaded on a CPU-only machine.
        """
        device = next(self.model.parameters()).device
        self.model.load_state_dict(torch.load(filepath, map_location=device))
        self.model.eval()

    def plot(self, fname=None, log=False):
        """Plot the per-epoch smoothed training loss.

        Parameters
        ----------
        fname : str, optional
            If given, the figure is saved to this path.
        log : bool, optional
            If True, use a log scale for the x-axis.
        """
        fig, ax = plt.subplots(1, 1, figsize=(12, 5), facecolor='#FFFFFF')
        ax.plot(self.mean_loss_values, color="purple")
        ax.set_ylabel('Loss')
        ax.set_xlabel('Epoch')
        if log:
            ax.set_xscale('log')
        if fname is not None:
            fig.savefig(fname)