Robin-JE-Bradfield's homepage

Bookmark this to keep an eye on my project updates!

Gene Perturbation Correlate Predictor

This project is intended to use data from the Cell x Gene Census, or other sources, to train a model that predicts how gene expression levels might change given an initial expression state and a perturbation to the expression of one or more genes. See the repository for the associated code and other files.

The current version uses an encoder-decoder architecture: the encoder applies self-attention to the initial expression levels, and the decoder then performs cross-attention with the perturbations as the queries and the encoded initial states as the keys and values. Expression levels are embedded by mapping gene Ensembl IDs to numeric tokens, using PyTorch's nn.Embedding to produce a learnable embedding for each gene, and then multiplying each gene's embedding vector by that gene's expression level, expressed as log(TPM + 1). Training data is generated by pairing up single-cell expression profiles drawn from the Census and computing a perturbation between the members of each pair.
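To make that concrete, here is a simplified PyTorch sketch of the idea; the names, layer counts, and dimensions are illustrative, and the real code in the repository differs in its details:

```python
import torch
import torch.nn as nn

class GenePerturbationModel(nn.Module):
    """Illustrative sketch only; not the repository's actual implementation."""

    def __init__(self, n_genes: int, d_model: int = 64, n_heads: int = 4, n_layers: int = 2):
        super().__init__()
        # One learnable vector per gene (Ensembl IDs mapped to integer tokens 0..n_genes-1).
        self.gene_embedding = nn.Embedding(n_genes, d_model)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True),
            num_layers=n_layers,
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model, n_heads, batch_first=True),
            num_layers=n_layers,
        )
        self.head = nn.Linear(d_model, 1)  # one predicted expression value per gene

    def embed(self, gene_ids: torch.Tensor, levels: torch.Tensor) -> torch.Tensor:
        # Scale each gene's embedding by its expression level, given as log(TPM + 1).
        return self.gene_embedding(gene_ids) * levels.unsqueeze(-1)

    def forward(self, gene_ids, initial_levels, perturbation_levels):
        # Encoder: self-attention over the embedded initial expression state.
        memory = self.encoder(self.embed(gene_ids, initial_levels))
        # Decoder: cross-attention with perturbations as queries and the
        # encoded initial state as keys and values.
        out = self.decoder(self.embed(gene_ids, perturbation_levels), memory)
        return self.head(out).squeeze(-1)

# Toy usage: a batch of two cells, each with 100 genes.
model = GenePerturbationModel(n_genes=100)
gene_ids = torch.arange(100).expand(2, -1)
predicted = model(gene_ids, torch.rand(2, 100), torch.rand(2, 100))
print(predicted.shape)  # torch.Size([2, 100])
```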

Importantly, this model is a toy; I have not built it with the expectation that it will solve real problems in practice. Its inspiration was the idea of producing a model capable of predicting the final expression state of a cell after the expression levels of one or more genes are directly changed, such as by RNA interference, gene editing, or the effect of an engineered signalling pathway, for use in synthetic biology. No matter how good the model itself becomes, the nature of its training data makes this impossible: it can only ever learn correlations between expression levels, not causation, since the ‘perturbations’ it is trained on are calculated from the expression states of two unrelated cells, rather than from an actual perturbation imposed on a single cell with the states being that cell's expression levels before and after. The purpose of this model is to help me develop my machine-learning skills (and indeed I have learned quite a bit from writing it; trying to solve a semi-realistic problem with deep learning is a useful way to refine and consolidate textbook knowledge), following on from having worked through the D2L Dive Into Deep Learning textbook. If I can make it learn anything at all beyond simple memorization of the training set, I will be quite happy!

The current version does not yet reach that goal; the loss fails to decrease noticeably over five epochs of training (results uploaded for interest). This is most likely due to the quality of the data. At present, the first hundred protein-coding genes in the Census are selected for inclusion in the model, with no consideration of whether they are actually expressed or measured in the data. My code includes features that mask unmeasured genes out of attention and the loss (a sketch of the idea appears after the list below), but those genes still occupy positions that could carry meaningful data, and judging from examples drawn from the dataloader, most genes are either not measured or not expressed. The model may therefore simply lack enough signal to learn any reliable patterns. Other possible causes include the learning rate, or an as-yet-unknown error in my code or approach. Possible next steps include the following:

- Restricting gene selection to genes that are actually measured and expressed in the Census data, so that every position carries useful signal.
- Sweeping the learning rate and other optimizer settings to rule out a simple tuning problem.
- Auditing the code and the overall approach for errors.
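For the masking idea mentioned above, here is a simplified sketch of how unmeasured genes can be kept out of the loss (and, via the same mask, out of attention); the function and variable names are illustrative rather than the repository's actual code:

```python
import torch

def masked_mse(pred: torch.Tensor, target: torch.Tensor, measured: torch.Tensor) -> torch.Tensor:
    """Mean squared error computed over measured genes only.

    pred, target: (batch, n_genes) predicted and true expression levels.
    measured: (batch, n_genes) boolean mask, True where the gene was measured.
    """
    sq_err = (pred - target) ** 2
    # Zero out unmeasured positions and normalize by the count of measured genes,
    # so unmeasured genes contribute nothing to the loss or its gradient.
    return (sq_err * measured).sum() / measured.sum().clamp(min=1)

# The same mask can exclude unmeasured genes from attention: nn.TransformerEncoder
# accepts src_key_padding_mask, where True marks positions to ignore, e.g.
# memory = encoder(x, src_key_padding_mask=~measured)

pred, target = torch.rand(2, 100), torch.rand(2, 100)
measured = torch.rand(2, 100) > 0.5
print(masked_mse(pred, target, measured))
```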

On the off chance that you are looking at this and find it interesting, you can contact me at rjebradfield@gmail.com with comments or to start a discussion.