With Covid-19 ravaging the world, a lot of people are exploring ways AI and ML can help in combating the spread of the virus. A virus like Covid-19 is a complex socio-economic and public health problem, and the solutions cut across many disciplines. In this post, the focus is on a very specific problem related to testing. Typical testing policies, like testing en masse or testing whoever wants to get tested, are not very effective. Wouldn’t it be great if we could predict the probability of viral infection for anyone based on recent contact history and then use the result to judiciously decide who should be tested?
In this post, we will go through a solution where personal contact data, treated as a sequence, is used along with an infection label to train a Long Short Term Memory (LSTM) network. The trained model can be deployed to predict the probability of infection for anyone with contact data. Healthcare authorities could use these predictions to select, in a data-driven way, the people who should be tested. The Python implementation based on PyTorch is available in my open-source GitHub repository avenir.
I am not a viral infection expert. But based on my personal studies, the infection process roughly works as follows. With every contact with the outside world, depending on the type of contact and exposure, a certain quantity of virus enters your body. Once inside they reproduce and the growth approximately follows an S or sigmoid curve.
Under attack, your body will activate its innate immune system. If it’s a known, previously encountered virus, the body will activate virus-specific antibodies. Once the viral load reaches a tipping point, symptoms start showing up and you have a full-blown infection. Either you recover or succumb to the virus, depending on which side in the war between the virus and the immune system wins.
At a macro level, there are 3 ways to combat viral infection spread as follows. Science and technology can make significant contributions in the first 2 areas. Herd immunity is a natural process that grows with time, and intervention of any kind won’t help.
- Testing and quarantine
- Herd immunity
Sequential Contact Data
In my model, each contact incident has the following attributes. Potentially there could be more. Essentially, contact data can be thought of as multidimensional time series data.
- Elapsed time in days since the contact incident happened
- Level of exposure. There are 4 of those. More on them later
- Whether wearing a mask or not (boolean)
- Vulnerability based on age and pre-existing condition (boolean)
- How pervasive the virus spread is in the geographical area. There are 3 levels.
It’s been reported that the nexus of the following 3 conditions poses the highest risk of infection. The 4 levels of exposure are based on how many of these conditions are true for a contact incident.
- Crowded place
- Covered place
- People talking loudly
Vulnerability is based on people being above a certain age and/or having pre-existing health conditions. There are three levels of infection intensity for a geographical area, with 1 being the lowest and 3 being the highest.
All the contact data listed here could be collected and curated in a real-world scenario. The contact data could be embellished with more attributes e.g. duration of the contact. It makes a difference whether someone is in a crowded place for 10 minutes or 1 hour.
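As a rough illustration of the attributes described above, here is a minimal sketch of how one contact incident could be turned into a 5-element vector, and a person's incidents into a sequence. The class, field, and function names are hypothetical and are not taken from avenir; the attribute order and the oldest-first ordering are assumptions that happen to match the sample records shown later in this post.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ContactIncident:
    """One contact incident; all names here are illustrative assumptions."""
    days_ago: int        # elapsed time in days since the contact
    crowded: bool        # crowded place
    covered: bool        # covered place
    loud_talking: bool   # people talking loudly
    mask: bool           # whether a mask was worn
    vulnerable: bool     # age and/or pre-existing condition
    area_level: int      # infection intensity in the area, 1 (low) to 3 (high)

    def exposure_level(self) -> int:
        # Exposure is the number of risk conditions that hold: 0, 1, 2 or 3.
        return int(self.crowded) + int(self.covered) + int(self.loud_talking)

    def to_vector(self) -> List[int]:
        return [self.days_ago, self.exposure_level(), int(self.mask),
                int(self.vulnerable), self.area_level]


def to_sequence(incidents: List[ContactIncident]) -> List[List[int]]:
    # Oldest contact first, so the sequence runs forward in time.
    ordered = sorted(incidents, key=lambda c: c.days_ago, reverse=True)
    return [c.to_vector() for c in ordered]
```

Extra attributes, such as contact duration, would simply become additional elements of the vector.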
Labeling the Training Data
Since the training data is synthetically generated, I used heuristics to label the data. The logic goes as follows.
- Each contact incident is mapped to some point in time in the sigmoid viral growth curve, which corresponds to some quantity of virus intake
- The viral intake quantity is grown as per the sigmoid curve up to the current time
- All the viral loads from all the contact incidents are summed.
- If the sum is above some threshold, the person is labeled, with high probability, as having the infection
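The steps above can be sketched in a few lines of Python. This is a simplified illustration and not the actual generator in avenir: the initial intake is passed in directly rather than derived from the contact attributes, and the growth rate, capacity, threshold, and label noise values are all assumptions chosen only to make the example run.

```python
import math
import random
from typing import List, Optional, Tuple


def grown_load(intake: float, days_elapsed: float,
               capacity: float = 1.0, rate: float = 0.5) -> float:
    """Logistic (sigmoid) growth of an initial intake over days_elapsed days."""
    # Logistic curve that equals `intake` at day 0 and saturates at `capacity`.
    ratio = (capacity - intake) / intake
    return capacity / (1.0 + ratio * math.exp(-rate * days_elapsed))


def label_person(contacts: List[Tuple[float, float]], threshold: float = 1.0,
                 noise: float = 0.1, rng: Optional[random.Random] = None) -> int:
    """contacts holds (days_ago, intake) pairs; returns 1 for infected, else 0."""
    rng = rng or random.Random()
    # Grow each intake up to the current time, then sum the viral loads.
    total = sum(grown_load(intake, days_ago) for days_ago, intake in contacts)
    label = 1 if total > threshold else 0
    # Flip the label with a small probability so the rule is not deterministic.
    if rng.random() < noise:
        label = 1 - label
    return label
```

Setting `noise` to 0 makes the labeling fully deterministic, which is handy for checking the growth and threshold logic in isolation.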
In a real-world scenario, once contact data has been collected, a person who tests positive will be labeled as having the infection.
As alluded to earlier, contact data is essentially time-based sequence data, so a recurrent network is the natural choice for modeling it. In a recurrent network, the output and the current hidden state are a function of the current input and the previous hidden state. Unlike a regular recurrent neural network (RNN), an LSTM can handle long term memory very well, but it has a more complex architecture. There is a lot of online content on the inner workings of LSTM. The blog I just cited has a very succinct and lucid introduction to LSTM.
The artifacts of an LSTM network are listed below. The input is a sequence of vectors or scalars. The output is a sequence of scalars or a single scalar, depending on the type of problem being solved.
The input gate controls what parts of the current input will be saved in long term memory, i.e. the cell state. The forget gate controls what parts of the incoming long term memory will be retained in the current long term memory. The output gate controls which parts of the current long term memory will be used to generate the output and the outgoing hidden state.
- Cell state (long term memory)
- Hidden state (short term memory)
- Input gate
- Forget gate
- Output gate
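To make the role of these artifacts concrete, here is a minimal sketch of one time step of an LSTM cell using the standard LSTM equations. For readability it assumes a scalar input and a hidden size of 1, so every weight is a single number; the parameter names in the dictionary `p` are made up for this example, and a real implementation such as PyTorch's works with weight matrices instead.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def lstm_step(x: float, h_prev: float, c_prev: float, p: dict):
    """One step of a scalar LSTM cell; returns (hidden state, cell state)."""
    i = sigmoid(p["wi"] * x + p["ui"] * h_prev + p["bi"])    # input gate
    f = sigmoid(p["wf"] * x + p["uf"] * h_prev + p["bf"])    # forget gate
    o = sigmoid(p["wo"] * x + p["uo"] * h_prev + p["bo"])    # output gate
    g = math.tanh(p["wg"] * x + p["ug"] * h_prev + p["bg"])  # candidate memory
    c = f * c_prev + i * g   # new cell state (long term memory)
    h = o * math.tanh(c)     # new hidden state (short term memory)
    return h, c
```

Processing a whole sequence amounts to calling `lstm_step` once per element, feeding the returned hidden and cell states into the next call.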
There are two kinds of sequence modeling problems, as listed below. Our problem belongs to the second category.
- Sequence to sequence, e.g. language translation
- Sequence to scalar, e.g. time series forecasting
A sequence to sequence model, as used in language translation, typically has one or more additional layers for decoding.
I have created a Python wrapper class around the PyTorch implementation of LSTM to make it easier to use, with the help of an elaborate configuration parameter file. By appropriately editing the sample configuration file, you will be able to create and train any LSTM network.
Training LSTM Model
The LSTM architecture used for training has the following attributes. Adding a second layer gave a big boost to the performance. Adding dropout gave a further boost. While adding dropout, I had to increase the number of hidden units.
- Two layers and unidirectional
- Adam optimizer (can be changed through configuration)
- Softmax activation for output (can be changed through configuration)
- Sequence length of 5 and input vector size of 5
- Hidden unit size of 100 (can be changed through configuration)
- Dropout probability of 0.5
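For readers who want to see what such a network looks like in plain PyTorch, here is a minimal sketch using the attributes listed above. It is not the wrapper class from avenir: the class name is hypothetical, and the placement of dropout (between the two LSTM layers, via the `dropout` argument of `nn.LSTM`), the use of the last time step as the sequence summary, and the learning rate are all assumptions.

```python
import torch
import torch.nn as nn


class ContactLSTM(nn.Module):
    """Two layer, unidirectional LSTM classifier (hypothetical class name)."""

    def __init__(self, input_size=5, hidden_size=100, num_layers=2,
                 dropout=0.5, num_classes=2):
        super().__init__()
        # dropout here is applied between the stacked LSTM layers.
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers,
                            dropout=dropout, batch_first=True)
        self.linear = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # x has shape (batch, sequence length, input vector size).
        out, _ = self.lstm(x)
        # Sequence to scalar: keep only the output of the last time step.
        logits = self.linear(out[:, -1, :])
        return torch.softmax(logits, dim=1)


model = ContactLSTM()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # lr is an assumption

batch = torch.rand(8, 5, 5)   # 8 random sequences of 5 contacts, 5 values each
probs = model(batch)          # shape (8, 2): [P(not infected), P(infected)]
```

Each output row is a pair of class probabilities, which is the same shape as the pairs printed in the validation output of the training run.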
The driver Python script can be used to generate training data and train the model. Details can be found in the tutorial document. I have done some manual parameter tuning and ended up with a recall of around 82% on test data. This is pretty impressive, considering that a 10% error was added in the synthetic data generation process.
Among the various performance metrics, recall is the most appropriate because we want to minimize false negatives, i.e. cases where a truly infected person is predicted as not being infected. Here is the tail end of the training console output, showing loss during training, predictions for validation data, and the final recall value for validation data. The performance metric can be chosen through the configuration file.
epoch 150 batch 0 loss 0.313277
epoch 150 batch 10 loss 0.316564
epoch 150 batch 20 loss 0.344409
epoch 150 batch 30 loss 0.313269
..validating model
predicted actual
[1.000000e+00 6.363314e-12] 0
[2.7581002e-14 1.0000000e+00] 1
[1.000000e+00 2.515283e-09] 0
.................................
[1.0000000e+00 3.6647692e-11] 0
[1.000000e+00 7.987434e-11] 0
[0.5680177 0.43198228] 0
[1.0614327e-08 1.0000000e+00] 1
[7.313520e-07 9.999993e-01] 1
[1.0000000e+00 4.4309636e-11] 0
[3.8157217e-05 9.9996185e-01] 1
perf score 0.814
..saving model checkpoint
model saved
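For reference, recall is the fraction of truly infected people that the model correctly flags as infected. The sketch below computes it from scratch; the function names and the small example lists are made up for illustration and are not part of the training output above.

```python
from typing import Sequence


def to_class(probs: Sequence[float]) -> int:
    # Pick the class with the highest predicted probability.
    return max(range(len(probs)), key=lambda k: probs[k])


def recall(actual: Sequence[int], predicted: Sequence[int],
           positive: int = 1) -> float:
    pairs = list(zip(actual, predicted))
    true_pos = sum(1 for a, p in pairs if a == positive and p == positive)
    false_neg = sum(1 for a, p in pairs if a == positive and p != positive)
    if true_pos + false_neg == 0:
        return 0.0  # no positive cases in the data
    return true_pos / (true_pos + false_neg)


# Hypothetical labels: 4 infected people, one of them missed by the model.
actual = [1, 1, 1, 1, 0, 0]
predicted = [1, 1, 1, 0, 0, 1]
score = recall(actual, predicted)
```

Note that the false positive in the last position does not affect recall at all; it would only show up in precision.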
In the real world scenario, contact data to be used for training the model should be collected when people get tested.
Prediction can be either binary or probabilistic, selectable through a configuration parameter. Probabilistic prediction is more useful here because public healthcare officials can set a probability threshold to select the people to be tested and/or placed in home quarantine. Here is some sample prediction output. Each record shows the contact data followed by the infection probability.
14,1,1,1,1,12,0,1,1,1,10,3,1,1,1,9,0,1,1,1,5,1,1,1,1 0.000
14,0,0,1,1,11,2,0,1,1,11,1,0,1,1,7,0,0,1,1,5,1,0,1,1 0.001
15,1,1,1,1,13,3,1,1,1,11,0,1,1,1,6,0,1,1,1,2,2,1,1,1 0.818
15,0,1,1,1,15,2,1,1,1,11,0,1,1,1,6,4,1,1,1,3,0,0,1,1 0.024
14,0,1,1,1,12,2,0,1,1,7,1,0,1,1,5,1,0,1,1,3,0,0,1,1 0.000
8,2,1,0,2,7,2,1,0,2,4,1,1,0,2,3,1,1,0,2,2,1,1,0,2 1.000
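Here is a minimal sketch of how a probability threshold could be applied to output in this format. The function names and the threshold of 0.5 are assumptions for illustration; in practice the threshold would be set by healthcare officials based on testing capacity. The three records in `sample` are copied from the output above.

```python
from typing import List, Tuple

SEQ_LEN, VEC_SIZE = 5, 5   # 5 contact incidents, 5 values per incident


def parse_prediction(line: str) -> Tuple[List[List[int]], float]:
    """Split one output line into the contact sequence and the probability."""
    record, prob = line.split()
    values = [int(v) for v in record.split(",")]
    if len(values) != SEQ_LEN * VEC_SIZE:
        raise ValueError("unexpected record length: %d" % len(values))
    seq = [values[k:k + VEC_SIZE] for k in range(0, len(values), VEC_SIZE)]
    return seq, float(prob)


def select_for_testing(lines: List[str], threshold: float = 0.5):
    """Return (sequence, probability) for everyone at or above the threshold."""
    parsed = [parse_prediction(line) for line in lines]
    return [(seq, prob) for seq, prob in parsed if prob >= threshold]


sample = [
    "14,1,1,1,1,12,0,1,1,1,10,3,1,1,1,9,0,1,1,1,5,1,1,1,1 0.000",
    "15,1,1,1,1,13,3,1,1,1,11,0,1,1,1,6,0,1,1,1,2,2,1,1,1 0.818",
    "8,2,1,0,2,7,2,1,0,2,4,1,1,0,2,3,1,1,0,2,2,1,1,0,2 1.000",
]
selected = select_for_testing(sample)
```

Lowering the threshold trades more tests for fewer missed infections, which is the same trade-off that motivates using recall as the performance metric.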
If the prediction service is deployed as a mobile app, people can provide their contact data for the last 15 days and get a prediction right away.
To use a tool like this in the real world, the following procedure could be followed. As people come to take the test, their contact data should be collected and curated. When the test results become available, the data should be labeled. Once data has been collected, the model should be trained and made available for use.
When testing resources and tools are scarce, the model could be used to test only high-risk people. Before performing any test, the model could be used to predict the probability of infection. The test could then be administered only to people with a high probability of infection.
When someone with a high probability of infection has been found, the people they have been in contact with could be found through contact tracing. Their contact data could be harvested and run through the prediction process.
We have seen an example of how deep learning technology can be harnessed to help in combating viral infection spread. We have also seen the capability of LSTM in modeling complex non-linear relationships involving sequence data.