Over the past year, I have been working on building a graph-based paratope (antibody binding site) prediction tool – Paragraph. Fortunately, I have had moderate success with this and you can now check out the preprint of this work here.
However, for a long time, I struggled with a highly unstable network, where different random seeds yielded very different results. I believe this instability was largely due to the high class imbalance in my data – only ~10% of all residues in the Fv (variable region of the antibody) belong to the paratope.
I tried many different things in an attempt to stabilise my training, most of which failed. I will share all of these ideas with you though – successful or not – as what works for one person/network is never guaranteed to work for another. I hope that the list below provides some ideas to try for others facing similar issues. Where possible, I also give example hyperparameter values that could act as sensible starting points, and after the list I include a few short PyTorch-style sketches showing how some of these options look in code.
- Use larger batch sizes – the more representative each batch is of your full dataset, the more likely each gradient step is to point in the right direction, e.g. 16, 64
- Normalise your loss – if you sum the loss over a batch rather than averaging it, you may be taking larger step sizes than you think
- Use L2 regularisation – penalises large weight values, e.g. a coefficient of 0.01
- Include dropout – randomly masks a proportion of units during training to prevent overfitting, e.g. 0.5, 0.8
- Try different loss measures – e.g. Binary Cross Entropy, or Focal Loss, which was designed with class imbalance in mind (see the focal loss sketch after this list)
- Add a learning rate scheduler – this changes the learning rate during training, e.g. exponential decay, OneCycle
- Use different optimisers – you might find that some optimisers are better at stepping in the right direction than others for your problem, e.g. Adam, AdamW, Adamax, SGD
- Include gradient clipping – this can stop the gradient from exploding if you have quite a deep network, e.g. clip thresholds of 0.1, 1, 10 (you can clip by value or by norm)
- Use double-precision floats – these can help if you are faced with vanishing gradients or other numerical underflow issues
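To make these more concrete, here is a minimal PyTorch-style sketch of a training loop combining several of the ideas above: larger batches, a mean-normalised loss, L2 regularisation via weight decay, dropout, a learning rate scheduler, and gradient clipping. The model and data are toy stand-ins (my actual network is a graph neural network), and the hyperparameter values are illustrative rather than the settings used in Paragraph:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for a per-residue binary classifier (the real model is a GNN).
class ResidueClassifier(nn.Module):
    def __init__(self, in_dim=20, hidden_dim=64, p_dropout=0.5):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(p_dropout),     # randomly zeroes activations during training
            nn.Linear(hidden_dim, 1),  # one logit per residue
        )

    def forward(self, x):
        return self.net(x).squeeze(-1)

# Dummy data: 1000 "residues" with 20 features and ~10% positive labels.
x = torch.randn(1000, 20)
y = (torch.rand(1000) < 0.1).float()
loader = DataLoader(TensorDataset(x, y), batch_size=64, shuffle=True)  # larger batches

model = ResidueClassifier()
# weight_decay applies L2 regularisation; AdamW decouples it from the adaptive update.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
# reduction="mean" normalises the loss over the batch, keeping step sizes consistent.
loss_fn = nn.BCEWithLogitsLoss(reduction="mean")

# To train in double precision instead: model.double() and cast inputs with .double().
for epoch in range(10):
    for xb, yb in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(xb), yb)
        loss.backward()
        # Clip the gradient norm to guard against exploding gradients
        # (clip_grad_value_ would instead clip element-wise by value).
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
    scheduler.step()  # decay the learning rate once per epoch
```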
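Focal loss deserves a special mention for imbalanced problems like this one: it down-weights the loss on easy, well-classified examples so that training concentrates on the hard minority class. A minimal binary version, assuming raw logits as input, might look like the following (alpha and gamma are the commonly used defaults, not values tuned for paratope prediction):

```python
import torch
import torch.nn.functional as F

def binary_focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """Binary focal loss: scales BCE by (1 - p_t)^gamma to down-weight easy examples."""
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p = torch.sigmoid(logits)
    p_t = p * targets + (1 - p) * (1 - targets)              # probability of the true class
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)  # class-balancing weight
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()

# Drop-in replacement for the loss in the training loop above:
# loss = binary_focal_loss(model(xb), yb)
```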
For my problem, I found that all of the above either did not help at all or improved the performance of some seeds while making others worse. What finally worked for me was realising that I could reduce the class imbalance of my problem by focusing my attention on just a small part of the Fv while still retaining the majority of binding residues (my data of interest). This change allowed me to reduce my class imbalance from 10:1 to 3:1 while keeping 92% of my positive data points; the sketch below illustrates the idea.
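In code, this kind of restriction is just a mask over the residues before training. The sketch below is purely schematic: the labels are fabricated to roughly reproduce the numbers above, and in practice the region mask would come from structural annotation of the Fv, not from the labels themselves:

```python
import torch

# Toy numbers mimicking the post: 1000 residues, 100 paratope positives,
# 92 of which fall inside a hypothetical 350-residue region of interest.
n, region_size = 1000, 350
y = torch.zeros(n)
y[torch.randperm(region_size)[:92]] = 1.0                    # positives inside the region
y[region_size + torch.randperm(n - region_size)[:8]] = 1.0   # positives outside it

in_region = torch.arange(n) < region_size  # boolean mask; real masks come from structure
y_sub = y[in_region]

pos, neg = y_sub.sum(), (1 - y_sub).sum()
print(f"imbalance in region: {neg / pos:.1f} : 1")  # ~2.8 : 1, down from ~9 : 1 overall
print(f"positives retained: {pos / y.sum():.0%}")   # 92%
```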
I hope that the ideas above prove helpful if you are also struggling with an unstable network. If you would like to understand more about my data preparation and network architecture, please do check out the preprint or GitHub repo!