
Commit a4f5b9b

Merge branch 'master' into jdaw/add-indel-encoding
2 parents 527b686 + 4d22a0a commit a4f5b9b


1 file changed: +40 -21 lines changed


docs/source/examples.rst

@@ -25,15 +25,16 @@ Training
 .. code-block:: python
 
     # Import nemo and variantworks modules
+    import os
     import nemo
-    from variantworks.dataloader import *
-    from variantworks.io.vcfio import *
-    from variantworks.networks import *
-    from variantworks.sample_encoders import *
+    from variantworks.dataloader import ReadPileupDataLoader
+    from variantworks.io.vcfio import VCFReader
+    from variantworks.networks import AlexNet
+    from variantworks.sample_encoder import PileupEncoder, ZygosityLabelEncoder
 
     # Create neural factory
     nf = nemo.core.NeuralModuleFactory(
-        placement=nemo.core.neural_factory.DeviceType.GPU, checkpoint_dir=tempdir)
+        placement=nemo.core.neural_factory.DeviceType.GPU, checkpoint_dir="./")
 
     # Create pileup encoder by selecting layers to encode. More encoding layers
     # can be found in the documentation for the PileupEncoder class.
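A note on the import change above: the wildcard imports are replaced with the explicit names the example actually uses. A minimal sketch of how those names come together; the constructor arguments are taken from the inference snippet later in this same file, while the no-argument ZygosityLabelEncoder() call and the zyg_encoder name are assumptions for illustration:

.. code-block:: python

    # Sketch of the objects the new explicit imports provide.
    from variantworks.sample_encoder import PileupEncoder, ZygosityLabelEncoder

    # Same layers and sizes as the inference example below.
    encoding_layers = [PileupEncoder.Layer.READ, PileupEncoder.Layer.BASE_QUALITY]
    pileup_encoder = PileupEncoder(
        window_size=100, max_reads=100, layers=encoding_layers)
    # Assumed no-argument construction; maps a variant's zygosity to a class label.
    zyg_encoder = ZygosityLabelEncoder()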
@@ -77,14 +78,27 @@ Training
     # Logger callback
     logger_callback = nemo.core.SimpleLossLoggerCallback(
         tensors=[vz_loss],
-        print_func=lambda x: logging.info(f'Train Loss: {str(x[0].item())}'))
+        print_func=lambda x: nemo.logging.info(f'Train Loss: {str(x[0].item())}')
+    )
+
+    # Checkpointing models through NeMo callback
+    checkpoint_callback = nemo.core.CheckpointCallback(
+        folder="./",
+        load_from_folder=None,
+        # Checkpointing frequency in steps
+        step_freq=-1,
+        # Checkpointing frequency in epochs
+        epoch_freq=1,
+        # Number of checkpoints to keep
+        checkpoints_to_keep=1,
+        # If True, CheckpointCallback will raise an Error if restoring fails
+        force_load=False
     )
 
     # Kick off training
     nf.train([vz_loss],
-             callbacks=[logger_callback,
-                        checkpoint_callback, evaluator_callback],
-             optimization_params={"num_epochs": 4, "lr": 0.001},
+             callbacks=[logger_callback, checkpoint_callback],
+             optimization_params={"num_epochs": 10, "lr": 0.001},
              optimizer="adam")
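With step_freq=-1 and epoch_freq=1, the callback added above checkpoints once per epoch rather than on a step cadence, and checkpoints_to_keep=1 retains only the latest checkpoint. A hedged sketch of resuming from that checkpoint; passing the write directory to load_from_folder is an assumption not shown in this diff:

.. code-block:: python

    # Hypothetical resume variant of the callback above: point
    # load_from_folder at the directory checkpoints were written to.
    checkpoint_callback = nemo.core.CheckpointCallback(
        folder="./",
        load_from_folder="./",  # assumed: restore the latest checkpoint first
        step_freq=-1,
        epoch_freq=1,
        checkpoints_to_keep=1,
        force_load=True         # raise instead of silently training from scratch
    )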
@@ -96,28 +110,30 @@ The inference pipeline works in a very similar fashion, except the final NeMo DA
 .. code-block:: python
 
     # Import nemo and variantworks modules
+    import os
     import nemo
-    from variantworks.dataloader import *
-    from variantworks.io.vcfio import *
-    from variantworks.networks import *
-    from variantworks.sample_encoders import *
-    from variantworks.result_writer import *
+    import torch
+    from variantworks.dataloader import ReadPileupDataLoader
+    from variantworks.io.vcfio import VCFReader
+    from variantworks.networks import AlexNet
+    from variantworks.sample_encoder import PileupEncoder, ZygosityLabelDecoder
+    from variantworks.result_writer import VCFResultWriter
 
     # Create neural factory. In this case, the checkpoint_dir has to be set for NeMo to pick
     # up a pre-trained model.
     nf = nemo.core.NeuralModuleFactory(
-        placement=nemo.core.neural_factory.DeviceType.GPU, checkpoint_dir=model_dir)
-
-    # Neural Network
-    model = AlexNet(num_input_channels=len(
-        encoding_layers), num_output_logits=3)
+        placement=nemo.core.neural_factory.DeviceType.GPU, checkpoint_dir="./")
 
     # Dataset generation is done in a similar manner. It's important to note that the encoder used
     # for inference must match that for training.
     encoding_layers = [PileupEncoder.Layer.READ, PileupEncoder.Layer.BASE_QUALITY]
     pileup_encoder = PileupEncoder(
         window_size=100, max_reads=100, layers=encoding_layers)
 
+    # Neural Network
+    model = AlexNet(num_input_channels=len(
+        encoding_layers), num_output_logits=3)
+
     # Similar to training, a dataloader needs to be set up for the relevant datasets. In the case of
     # inference, it doesn't matter if the files are tagged as false positive or not. Each example will be
     # evaluated by the network. For simplicity the example is using the same dataset from training.
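The moved "Neural Network" block makes the dependency explicit: the model can only be constructed after encoding_layers is defined, because its input width is derived from it. A small restatement of that relationship, using only names from the hunk above:

.. code-block:: python

    # The model's input width is tied to the encoder configuration:
    # each PileupEncoder layer becomes one input channel.
    encoding_layers = [PileupEncoder.Layer.READ, PileupEncoder.Layer.BASE_QUALITY]
    assert len(encoding_layers) == 2
    # Three output logits, one per zygosity class.
    model = AlexNet(num_input_channels=len(encoding_layers), num_output_logits=3)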
@@ -135,21 +151,24 @@ The inference pipeline works in a very similar fashion, except the final NeMo DA
     vz = model(encoding=encoding)
 
     # Invoke the "infer" action.
-    results = nf.infer([vz], checkpoint_dir=model_dir, verbose=True)
+    results = nf.infer([vz], checkpoint_dir="./", verbose=True)
 
     # Instantiate a decoder that converts the predicted output of the network to
     # a zygosity enum.
     zyg_decoder = ZygosityLabelDecoder()
 
     # Decode inference results to labels
+    inferred_zygosity = []
     for tensor_batches in results:
         for batch in tensor_batches:
             predicted_classes = torch.argmax(batch, dim=1)
             inferred_zygosity += [zyg_decoder(pred)
                                   for pred in predicted_classes]
 
     # Use the VCFResultWriter to output predicted zygosities to a VCF file.
-    result_writer = VCFResultWriter(vcf_loader, inferred_zygosity)
+    result_writer = VCFResultWriter(vcf_loader,
+                                    inferred_zygosities=inferred_zygosity,
+                                    output_location="./")
 
     result_writer.write_output()
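The decode loop above reduces each batch of logits to class indices before ZygosityLabelDecoder maps them to zygosity enums. That reduction step is plain PyTorch; a self-contained illustration with made-up values:

.. code-block:: python

    import torch

    # One batch of two examples, three zygosity logits each (made-up values).
    batch = torch.tensor([[0.1, 2.3, -0.5],
                          [1.7, 0.2,  0.4]])
    predicted_classes = torch.argmax(batch, dim=1)  # tensor([1, 0])
    # zyg_decoder(pred) then maps each index to the corresponding zygosity enum.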
