Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

prototype torch #372

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions covasim/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
'''

import numpy as np
import torch
import pandas as pd
import sciris as sc
import datetime as dt
Expand Down Expand Up @@ -1264,15 +1265,15 @@ def add_contacts(self, contacts, lkey=None, beta=None):
if beta is None:
beta = 1.0
beta = cvd.default_float(beta)
new_layer['beta'] = np.ones(n, dtype=cvd.default_float)*beta
new_layer['beta'] = torch.ones(n, dtype=torch.float32, device="cuda")*beta

# Create the layer if it doesn't yet exist
if lkey not in self.contacts:
self.contacts[lkey] = Layer(label=lkey)

# Actually include them, and update properties if supplied
for col in self.contacts[lkey].keys(): # Loop over the supplied columns
self.contacts[lkey][col] = np.concatenate([self.contacts[lkey][col], new_layer[col]])
self.contacts[lkey][col] = torch.cat((self.contacts[lkey][col], new_layer[col]))
self.contacts[lkey].validate()

return
Expand Down Expand Up @@ -1307,7 +1308,7 @@ def make_edgelist(self, contacts):
for lkey in lkeys:
new_layer = Layer(label=lkey)
for ckey,value in new_contacts[lkey].items():
new_layer[ckey] = np.array(value, dtype=new_layer.meta[ckey])
new_layer[ckey] = torch.tensor(value, device='cuda')
new_contacts[lkey] = new_layer

return new_contacts
Expand Down Expand Up @@ -1501,7 +1502,7 @@ def __init__(self, label=None, **kwargs):

# Initialize the keys of the layers
for key,dtype in self.meta.items():
self[key] = np.empty((0,), dtype=dtype)
self[key] = torch.empty((0,), dtype=torch.float32, device='cuda')

# Set data, if provided
for key,value in kwargs.items():
Expand Down Expand Up @@ -1558,7 +1559,7 @@ def validate(self):
for key,dtype in self.meta.items():
if dtype:
actual = self[key].dtype
expected = dtype
expected = torch.float32
if actual != expected:
errormsg = f'Expecting dtype "{expected}" for layer key "{key}"; got "{actual}"'
raise TypeError(errormsg)
Expand Down
8 changes: 4 additions & 4 deletions covasim/people.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,10 +435,10 @@ def infect(self, inds, hosp_max=None, icu_max=None, source=None, layer=None, var
if len(inds) == 0:
return 0

# Remove duplicates
inds, unique = np.unique(inds, return_index=True)
if source is not None:
source = source[unique]
# # Remove duplicates
# inds, unique = np.unique(inds, return_index=True)
# if source is not None:
# source = source[unique]

# Keep only susceptibles
keep = self.susceptible[inds] # Unique indices in inds and source that are also susceptible
Expand Down
7 changes: 4 additions & 3 deletions covasim/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#%% Imports
import numpy as np
import torch
import pandas as pd
import sciris as sc
from . import utils as cvu
Expand All @@ -17,6 +18,7 @@
from . import immunity as cvimm
from . import analysis as cva


# Almost everything in this file is contained in the Sim class
__all__ = ['Sim', 'diff_sims', 'demo', 'AlreadyRunError']

Expand Down Expand Up @@ -626,10 +628,9 @@ def step(self):
beta = cvd.default_float(self['beta'] * rel_beta)

for lkey, layer in contacts.items():
p1 = layer['p1']
p2 = layer['p2']
p1 = layer['p1'].long()
p2 = layer['p2'].long()
betas = layer['beta']

# Compute relative transmission and susceptibility
inf_variant = people.infectious * (people.infectious_variant == variant) # TODO: move out of loop?
sus_imm = people.sus_imm[variant,:]
Expand Down
61 changes: 47 additions & 14 deletions covasim/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,27 +91,60 @@ def compute_trans_sus(rel_trans, rel_sus, inf, sus, beta_layer,
return rel_trans, rel_sus


@nb.njit( (nbfloat, nbint[:], nbint[:], nbfloat[:], nbfloat[:], nbfloat[:]), cache=cache, parallel=rand_parallel)
def compute_infections(beta, sources, targets, layer_betas, rel_trans, rel_sus): # pragma: no cover
import torch, time


# @nb.njit( (nbfloat, nbint[:], nbint[:], nbfloat[:], nbfloat[:], nbfloat[:]), cache=cache, parallel=rand_parallel)
# def compute_infections(beta, sources, targets, layer_betas, rel_trans, rel_sus): # pragma: no cover
# '''
# Compute who infects whom

# The heaviest step of the model, taking about 50% of the total time -- figure
# out who gets infected on this timestep. Cannot be easily parallelized since
# random numbers are used.
# '''
# source_trans = rel_trans[sources] # Pull out the transmissibility of the sources (0 for non-infectious people)
# inf_inds = source_trans.nonzero()[0] # Infectious indices -- remove noninfectious people
# betas = beta * layer_betas[inf_inds] * source_trans[inf_inds] * rel_sus[targets[inf_inds]] # Calculate the raw transmission probabilities
# nonzero_inds = betas.nonzero()[0] # Find nonzero entries
# nonzero_inf_inds = inf_inds[nonzero_inds] # Map onto original indices
# nonzero_betas = betas[nonzero_inds] # Remove zero entries from beta
# nonzero_sources = sources[nonzero_inf_inds] # Remove zero entries from the sources
# nonzero_targets = targets[nonzero_inf_inds] # Remove zero entries from the targets
# transmissions = (np.random.random(len(nonzero_betas)) < nonzero_betas).nonzero()[0] # Compute the actual infections!

# source_inds = nonzero_sources[transmissions]
# target_inds = nonzero_targets[transmissions] # Filter the targets on the actual infections
# return source_inds, target_inds


def compute_infections(beta, sources, targets, layer_betas, rel_trans, rel_sus): # pragma: no cover
'''
Compute who infects whom

The heaviest step of the model, taking about 50% of the total time -- figure
out who gets infected on this timestep. Cannot be easily parallelized since
random numbers are used.
'''
source_trans = rel_trans[sources] # Pull out the transmissibility of the sources (0 for non-infectious people)
inf_inds = source_trans.nonzero()[0] # Infectious indices -- remove noninfectious people
betas = beta * layer_betas[inf_inds] * source_trans[inf_inds] * rel_sus[targets[inf_inds]] # Calculate the raw transmission probabilities
nonzero_inds = betas.nonzero()[0] # Find nonzero entries
nonzero_inf_inds = inf_inds[nonzero_inds] # Map onto original indices
nonzero_betas = betas[nonzero_inds] # Remove zero entries from beta
nonzero_sources = sources[nonzero_inf_inds] # Remove zero entries from the sources
nonzero_targets = targets[nonzero_inf_inds] # Remove zero entries from the targets
transmissions = (np.random.random(len(nonzero_betas)) < nonzero_betas).nonzero()[0] # Compute the actual infections!
source_inds = nonzero_sources[transmissions]
target_inds = nonzero_targets[transmissions] # Filter the targets on the actual infections
return source_inds, target_inds

rel_trans = torch.from_numpy(rel_trans).cuda()
rel_sus = torch.from_numpy(rel_sus).cuda()

source_trans = rel_trans[sources.long()] # Pull out the transmissibility of the sources (0 for non-infectious people)
inf_inds = torch.flatten(source_trans.nonzero().long()) # Infectious indices -- remove noninfectious people
betas = torch.flatten(beta * layer_betas[inf_inds] * source_trans[inf_inds] * rel_sus[targets[inf_inds]]) # Calculate the raw transmission probabilities
nonzero_inds = torch.flatten(betas.nonzero()) # Find nonzero entries
nonzero_inf_inds = torch.flatten(inf_inds[nonzero_inds]) # Map onto original indices
nonzero_betas = torch.flatten(betas[nonzero_inds]) # Remove zero entries from beta
nonzero_sources = torch.flatten(sources[nonzero_inf_inds]) # Remove zero entries from the sources
nonzero_targets = torch.flatten(targets[nonzero_inf_inds]) # Remove zero entries from the targets
transmissions = torch.flatten((torch.rand(len(nonzero_betas)).cuda() < nonzero_betas).nonzero()) # Compute the actual infections!

source_inds = torch.flatten(nonzero_sources[transmissions])
target_inds = torch.flatten(nonzero_targets[transmissions]) # Filter the targets on the actual infections

return source_inds.cpu().numpy(), target_inds.cpu().numpy()



@nb.njit((nbint[:], nbint[:], nb.int64[:]), cache=cache)
Expand Down