-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
revised and simplified scripts. Converted bash scripts to python. Upd…
…ated with random access problems (compressive random access,massive MIMO channel estimation)
- Loading branch information
1 parent
120ec7a
commit 54a39d3
Showing
23 changed files
with
962 additions
and
1,949 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
*.pyc | ||
*.npz | ||
*.mat | ||
*.swp |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
#!/usr/bin/python | ||
from __future__ import division | ||
from __future__ import print_function | ||
""" | ||
This file serves as an example of how to | ||
a) select a problem to be solved | ||
b) select a network type | ||
c) train the network to minimize recovery MSE | ||
""" | ||
import numpy as np | ||
import os | ||
|
||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # BE QUIET!!!! | ||
import tensorflow as tf | ||
|
||
np.random.seed(1) # numpy is good about making repeatable output | ||
tf.set_random_seed(1) # on the other hand, this is basically useless (see issue 9171) | ||
|
||
# import our problems, networks and training modules | ||
from tools import problems,networks,train | ||
|
||
# Create the basic problem structure. | ||
prob = problems.bernoulli_gaussian_trial(kappa=None,M=250,N=500,L=1000,pnz=.1,SNR=40) #a Bernoulli-Gaussian x, noisily observed through a random matrix | ||
#prob = problems.random_access_problem(2) # 1 or 2 for compressive random access or massive MIMO | ||
|
||
# build a LAMP network to solve the problem and get the intermediate results so we can greedily extend and then refine(fine-tune) | ||
layers = networks.build_LAMP(prob,T=6,shrink='bg',untied=False) | ||
|
||
# plan the learning | ||
training_stages = train.setup_training(layers,prob,trinit=1e-3,refinements=(.5,.1,.01) ) | ||
|
||
# do the learning (takes a while) | ||
sess = train.do_training(training_stages,prob,'LAMP_bg_giid.npz') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
#!/usr/bin/python | ||
from __future__ import division | ||
from __future__ import print_function | ||
""" | ||
This file serves as an example of how to | ||
a) select a problem to be solved | ||
b) select a network type | ||
c) train the network to minimize recovery MSE | ||
""" | ||
import numpy as np | ||
import os | ||
|
||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # BE QUIET!!!! | ||
import tensorflow as tf | ||
|
||
np.random.seed(1) # numpy is good about making repeatable output | ||
tf.set_random_seed(1) # on the other hand, this is basically useless (see issue 9171) | ||
|
||
# import our problems, networks and training modules | ||
from tools import problems,networks,train | ||
|
||
# Create the basic problem structure. | ||
prob = problems.bernoulli_gaussian_trial(kappa=None,M=250,N=500,L=1000,pnz=.1,SNR=40) #a Bernoulli-Gaussian x, noisily observed through a random matrix | ||
#prob = problems.random_access_problem(2) # 1 or 2 for compressive random access or massive MIMO | ||
|
||
# build a LISTA network to solve the problem and get the intermediate results so we can greedily extend and then refine(fine-tune) | ||
layers = networks.build_LISTA(prob,T=6,initial_lambda=.1,untied=False) | ||
|
||
# plan the learning | ||
training_stages = train.setup_training(layers,prob,trinit=1e-3,refinements=(.5,.1,.01) ) | ||
|
||
# do the learning (takes a while) | ||
sess = train.do_training(training_stages,prob,'LISTA_bg_giid.npz') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
#!/usr/bin/python | ||
from __future__ import division | ||
from __future__ import print_function | ||
""" | ||
This file serves as an example of how to | ||
a) select a problem to be solved | ||
b) select a network type | ||
c) train the network to minimize recovery MSE | ||
""" | ||
import numpy as np | ||
import os | ||
|
||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # BE QUIET!!!! | ||
import tensorflow as tf | ||
|
||
np.random.seed(1) # numpy is good about making repeatable output | ||
tf.set_random_seed(1) # on the other hand, this is basically useless (see issue 9171) | ||
|
||
# import our problems, networks and training modules | ||
from tools import problems,networks,train | ||
|
||
# Create the basic problem structure. | ||
prob = problems.bernoulli_gaussian_trial(kappa=None,M=250,N=500,L=1000,pnz=.1,SNR=40) #a Bernoulli-Gaussian x, noisily observed through a random matrix | ||
#prob = problems.random_access_problem(2) # 1 or 2 for compressive random access or massive MIMO | ||
|
||
# build an LVAMP network to solve the problem and get the intermediate results so we can greedily extend and then refine(fine-tune) | ||
layers = networks.build_LVAMP(prob,T=6,shrink='bg') | ||
#layers = networks.build_LVAMP_dense(prob,T=3,shrink='pwgrid') | ||
|
||
# plan the learning | ||
training_stages = train.setup_training(layers,prob,trinit=1e-4,refinements=(.5,.1,.01)) | ||
|
||
# do the learning (takes a while) | ||
sess = train.do_training(training_stages,prob,'LVAMP_bg_giid.npz') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,15 @@ | ||
# What is this? | ||
|
||
This project contains scripts to reproduce experiments from the paper | ||
[Onsager-Corrected Deep Networks for Sparse Linear Inverse Problems](https://arxiv.org/pdf/1612.01183) | ||
[AMP-Inspired Deep Networks for Sparse Linear Inverse Problems](http://ieeexplore.ieee.org/document/7934066/) | ||
by | ||
[Mark Borgerding](mailto://[email protected]) | ||
and | ||
, | ||
[Phil](mailto://[email protected]) | ||
[Schniter](http://www2.ece.ohio-state.edu/~schniter) | ||
|
||
, and [Sundeep Rangan](http://engineering.nyu.edu/people/sundeep-rangan). | ||
To appear in IEEE Transactions on Signal Processing. | ||
See also the related [preprint](https://arxiv.org/pdf/1612.01183) | ||
|
||
# The Problem of Interest | ||
|
||
|
@@ -19,12 +21,13 @@ The included scripts | |
- are generally written in python and require [TensorFlow](http://www.tensorflow.org), | ||
- work best with a GPU, | ||
- generate synthetic data as needed, | ||
- are known to work with CentOS 7 Linux and TensorfFlow 0.9, | ||
- may sometimes be written in octave/matlab .m files. | ||
- are known to work with CentOS 7 Linux and TensorfFlow 1.1, | ||
- are sometimes be written in octave/matlab .m files. | ||
|
||
## If you are just looking for an implementation of VAMP ... | ||
|
||
You might prefer the Matlab code in [GAMP](https://sourceforge.net/projects/gampmatlab/)/code/VAMP/ | ||
or the python code in [Vampyre](https://github.com/GAMPTeam/vampyre). | ||
|
||
# Description of Files | ||
|
||
|
@@ -34,17 +37,6 @@ Creates numpy archives (.npz) and matlab (.mat) files with (y,x,A) for the spars | |
These files are not really necessary for any of the deep-learning scripts, which generate the problem on demand. | ||
They are merely provided for better understanding the specific realizations used in the experiments. | ||
|
||
e.g. | ||
``` | ||
$ ./save_problem.py | ||
... | ||
saved numpy archive problem_Giid.npz | ||
saved matlab file problem_Giid.mat | ||
... | ||
saved numpy archive problem_k15.npz | ||
saved matlab file problem_k15.mat | ||
``` | ||
|
||
## [ista_fista_amp.m](ista_fista_amp.m) | ||
|
||
Using the .mat files created by save_problem.py, this octave/matlab script tests the performance of non-learned algorithms ISTA, FISTA, and AMP. | ||
|
@@ -61,98 +53,15 @@ ISTA reached NMSE=-35dB at iteration 3420 | |
ISTA terminal NMSE=-36.7419 dB | ||
``` | ||
|
||
## [lista.py](lista.py) | ||
|
||
This is an implementation of LISTA _Learned Iterative Soft Thresholding Algorithm_ by (Gregor&LeCun, 2010 ICML). | ||
|
||
e.g. To reproduce the `LISTA` trace from Fig.9, | ||
``` | ||
$ ./lista.py --T 1 --save /tmp/T1.npz --trainRate=1e-3 --refinements=4 --stopAfter=20 | ||
... | ||
step=15990 elapsed=133.239667892 nic=319 nmse_test=-6.40807334129 nmse_val=-6.46110795422 nmse_val0=-0.806287242 | ||
no improvement in 320 checks,Test NMSE=-6.408dB | ||
$ for T in {2..20};do ./lista.py --T $T --save /tmp/T${T}.npz --load /tmp/T$(($T-1)).npz --setup --trainRate=1e-3 --refinements=4 --stopAfter=20 --summary /tmp/${T}.sum || break ;done | ||
... | ||
``` | ||
The `nmse_val` is the quantity that is plotted in the paper. It is from a mini-batch that is used for training or iany decisions. The `nmse_test` is from a minibatch that is also not trained, but it _is_ used for decisions about training step size, and termination criteria. This convention holds for all experiments. | ||
|
||
## [lamp_vamp.py](lamp_vamp.py) | ||
|
||
Learns the parameters for Learned AMP (LAMP) or Vector AMP(VAMP) with a variety of shrinkage functions. | ||
This script may be called independently or from run_lamp_vamp.sh. | ||
|
||
e.g. The following generates the `matched VAMP` trace from Fig.12 | ||
``` | ||
$ for T in {1..15};do ./lamp_vamp.py --matched --T $T --summary=matched${T}.sum;done | ||
... | ||
$ for T in {1..15};do echo -n "T=$T "; grep nmse_val= matched${T}.sum;done | ||
T=1 nmse_val=-6.74708897423 | ||
T=2 nmse_val=-12.5694582254 | ||
T=3 nmse_val=-18.8778007058 | ||
T=4 nmse_val=-25.7153599678 | ||
T=5 nmse_val=-32.8098204058 | ||
T=6 nmse_val=-39.1792426565 | ||
T=7 nmse_val=-43.3195721343 | ||
T=8 nmse_val=-44.9222227945 | ||
T=9 nmse_val=-45.3680144768 | ||
T=10 nmse_val=-45.4783550406 | ||
T=11 nmse_val=-45.4985886728 | ||
T=12 nmse_val=-45.5054164287 | ||
T=13 nmse_val=-45.5063294603 | ||
T=14 nmse_val=-45.50776381 | ||
T=15 nmse_val=-45.5077351689 | ||
``` | ||
Here, the `--matched` argument bypasses training by forcing some argument values, specifically `--vamp --shrink bg --trainRate=0`. | ||
|
||
## [run_lamp_vamp.sh](run_lamp_vamp.sh) | ||
|
||
bash script to drive lamp_vamp.py with different shrinkage functions, algorithms, matrix types, etc. | ||
This takes days to run, even with a fast GPU. | ||
|
||
## [let_vamp_off_leash.py](let_vamp_off_leash.py) | ||
|
||
This demonstrates that matched VAMP represents a fixed point for a Learned VAMP (LVAMP) network. | ||
The network is given the benefit of a large number of parameters. | ||
One might expect that deep learning/backpropagation would yield some improvements over the prescribed structure and values of | ||
VAMP. | ||
|
||
Notably we find that **backpropagation yields no improvement to LVAMP when initialized with matched parameters**. | ||
``` | ||
$ ./let_vamp_off_leash.py --T 6 --trainRate=1e-5 --refinements=0 --stopAfter=200 | ||
... | ||
step=10 elapsed=0.605620861053 nic=1 nmse_test=-38.9513696476 nmse_val=-39.2820689456 nmse_val0=-39.2881639853 | ||
... | ||
step=1990 elapsed=57.8836979866 nic=199 nmse_test=-38.9146799477 nmse_val=-39.2407649471 nmse_val0=-39.2881639853 | ||
``` | ||
|
||
Then the initialization is slightly perturbed away from matched parameters such that the initial performrance is almost 6dB | ||
worse. | ||
We see that backpropagation does its job and finds its way back to | ||
approximately the same level as with the matched parameters. The slight difference is explainable by different realizations of the training minibatches. | ||
``` | ||
$ ./let_vamp_off_leash.py --T 6 --trainRate=1e-5 --refinements=0 --stopAfter=200 --errTheta .2 --errV 1e-3 --errISU 1e-3 --errRS2 1e-3 | ||
... | ||
step=10 elapsed=1.27413392067 nic=0 nmse_test=-32.9629743724 nmse_val=-33.0911397806 nmse_val0=-30.0131020631 | ||
... | ||
step=10050 elapsed=585.011854887 nic=199 nmse_test=-38.9203152625 nmse_val=-39.2310257925 nmse_val0=-30.0131020631 | ||
``` | ||
|
||
|
||
## [shrinkage.py](shrinkage.py) | ||
## [LISTA.py](LISTA.py) | ||
|
||
python module which defines the shrinkage functions we investigated and parameterized | ||
This is an example implementation of LISTA _Learned Iterative Soft Thresholding Algorithm_ by (Gregor&LeCun, 2010 ICML). | ||
|
||
- soft-threshold (scaled) | ||
- piecewise-linear | ||
- exponential | ||
- spline-based | ||
- Bernoulli-Gaussian MMSE | ||
## [LAMP.py](LAMP.py) | ||
|
||
## [utils.py](utils.py) | ||
Example of Learned AMP (LAMP) with a variety of shrinkage functions. | ||
|
||
Various python/tensorflow utilities. | ||
## [LVAMP.py](LVAMP.py) | ||
|
||
## [utils.sh](utils.sh) | ||
Example of Learned Vector AMP (LVAMP). | ||
|
||
Various shell script utility functions. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.