Initial refactoring #4

Open · wants to merge 2 commits into base: master
273 changes: 156 additions & 117 deletions convert2HMI.py
@@ -1,169 +1,208 @@
-import os
-import sys
-
 import argparse
+import yaml
 import logging
+from itertools import chain
+from pathlib import Path

 import numpy as np
-from astropy.io import fits
-import astropy.units as u
-import sunpy.map
-
 import torch
-import yaml
 from torch.utils.data.dataloader import DataLoader

 from source.models.model_manager import BaseScaler
-from source.dataset import FitsFileDataset
-from source.data_utils import get_array_radius, get_image_from_array, plot_magnetogram
+from source.utils import get_logger


-def get_logger(name):
+def get_config(instrument, fulldisk, zero_outside, add_noise, no_rescale, **kwargs):
     """
-    Return a logger for current module
+    Get config object setting values passed.
+
+    Parameters
+    ----------
+    instrument : str
+        Instrument name
+    fulldisk : bool
+        Fulldisk based inference, default is patch based
+    zero_outside :
+        Set region outside solar radius to zeros instead of default nan values
+    add_noise : float (optional)
+        Scale or standard deviation of noise to add
+    no_rescale : bool (optional)
+        Disable rescaling

     Returns
     -------
-    logger : logger instance
+    tuple
+        Run and config dict
     """
-    logger = logging.getLogger(name)
-    logger.setLevel(logging.DEBUG)
-    formatter = logging.Formatter(fmt="%(asctime)s %(levelname)s %(name)s: %(message)s",
-                                  datefmt="%Y-%m-%d - %H:%M:%S")
-    console = logging.StreamHandler(sys.stdout)
-    console.setLevel(logging.DEBUG)
-    console.setFormatter(formatter)
-
-    logfile = logging.FileHandler('run.log', 'w')
-    logfile.setLevel(logging.DEBUG)
-    logfile.setFormatter(formatter)
-
-    logger.addHandler(console)
-    logger.addHandler(logfile)
-
-    return logger
-
-
-if __name__ == '__main__':
-    logger = get_logger(__name__)
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--instrument', required=True)
-    parser.add_argument('--data_path')
-    parser.add_argument('--destination')
-    parser.add_argument('--add_noise')
-    parser.add_argument('--plot', action='store_true')
-    parser.add_argument('--overwrite', action='store_true')
-    parser.add_argument('--use_patches', action='store_true')
-    parser.add_argument('--zero_outside', action='store_true')
-    parser.add_argument('--no_rescale', action='store_true')
-
-    args = parser.parse_args()
-    instrument = args.instrument.lower()
-
     if instrument == 'mdi':
-        run = 'checkpoints/mdi/20200312194454_HighResNet_RPRCDO_SSIMGradHistLoss_mdi_19'
+        run_dir = Path('checkpoints/mdi/20200312194454_HighResNet_RPRCDO_SSIMGradHistLoss_mdi_19')
Comment (Collaborator, author):
We should always use pathlib.Path for paths in general; it is also a must if we want this to work on Windows.
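For illustration, a minimal sketch of the pathlib idioms this refactor adopts (the file and directory names below are made up):

```python
from pathlib import Path

# Join with '/' instead of string concatenation; works on Windows and POSIX.
run_dir = Path('checkpoints') / 'mdi' / 'example_run'       # hypothetical run name

# with_suffix appends or replaces the extension, as used for the .yml config here.
config_file = run_dir.with_suffix('.yml')                   # checkpoints/mdi/example_run.yml

# rglob replaces os.walk plus endswith filtering.
fits_files = list(Path('data').rglob('*.fits'))             # hypothetical data dir

# stem strips the last suffix, handy for building output names.
out_path = Path('out') / (Path('mag.fits').stem + '_HR.fits')   # out/mag_HR.fits
```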

     elif instrument == 'gong':
-        run = 'checkpoints/gong/20200321142757_HighResNet_RPRCDO_SSIMGradHistLoss_gong_1'
+        run_dir = Path('checkpoints/gong/20200321142757_HighResNet_RPRCDO_SSIMGradHistLoss_gong_1')
     else:
         raise RuntimeError(f'mdi and gong are the only valid instruments.')

-    with open(run + '.yml', 'r') as stream:
+    with run_dir.with_suffix('.yml').open() as stream:
         config_data = yaml.load(stream, Loader=yaml.SafeLoader)

+    config_data['cli'] = {'fulldisk': fulldisk,
+                          'zero_outside': zero_outside,
+                          'add_noise': add_noise}
+
     config_data['instrument'] = instrument
     data_config = config_data['data']
-    norm = 3500
-    if 'normalisation' in data_config.keys():
-        norm = data_config['normalisation']
+
+    if 'normalisation' not in data_config.keys():
+        data_config['normalisation'] = 3500.0

     padding = np.nan
-    if args.zero_outside:
+    if zero_outside:
         padding = 0
     data_config['padding'] = padding

     rescale = True
-    if args.no_rescale:
+    if no_rescale:
         rescale = False
     data_config['rescale'] = rescale

     net_config = config_data['net']
-    model_name = net_config['name']
-    upscale_factor = 4
-    if 'upscale_factor' in net_config.keys():
-        upscale_factor = net_config['upscale_factor']
-
-    model = BaseScaler.from_dict(config_data)
-
-    device = torch.device("cpu")
-    model = model.to(device)
-
-    checkpoint = torch.load(run, map_location='cpu')
-    try:
-        try:
-            model.load_state_dict(checkpoint['model_state_dict'])
-        except:
-            state_dict = {}
-            for key, value in checkpoint['model_state_dict'].items():
-                state_dict['.'.join(key.split('.')[1:])] = value
-            model.load_state_dict(state_dict)
-    except:
-        state_dict = {}
-        for key, value in checkpoint['model_state_dict'].items():
-            state_dict['.'.join(np.append(['module'], key.split('.')[0:]))] = value
-        model.load_state_dict(state_dict)
+
+    if 'upscale_factor' not in net_config.keys():
+        net_config['upscale_factor'] = 4
+
+    return run_dir, config_data
+
+
+def get_model(run, config):
+    """
+    Get a model based on the run and config data
+
+    Parameters
+    ----------
+    run : pathlib.Path
+        Path to run directory
+    config : dict
+        Config data
+
+    Returns
+    -------
+    source.model.model_manger.TemplateModel
+        The model
+    """
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    config['device'] = device
+    amodel = BaseScaler.from_dict(config)
+    amodel = amodel.to(device)
+    checkpoint = torch.load(run.as_posix(), map_location='cpu')
+    state_dict = {}
+    for key, value in checkpoint['model_state_dict'].items():
+        if key.startswith('module'):
+            state_dict['.'.join(key.split('.')[1:])] = value
+        else:
+            state_dict[key] = value
+    amodel.load_state_dict(state_dict)
+
+    return amodel
+
+
+def convert(in_file, out_file, config, patchsize=32):
+    """
+    Convert a file to HMI
+
+    Parameters
+    ----------
+    in_file : pathlib.Path
+        Input fits file
+    out_file : pathlib.Path
+        Output fits file
+    config : dict
+        Configuration dictionary
+    patchsize : int
+        Size of the patches created
+
+    Returns
+    -------
+
+    """
+    # Really slow imports so only import if we reach the point where it is needed
+    from source.dataset import FitsFileDataset
+    from source.data_utils import get_array_radius, get_image_from_patches, plot_magnetogram
+
+    norm = config['data']['normalisation']
+    device = config['device']
+    fulldisk = config['cli']['fulldisk']
+
+    file_dset = FitsFileDataset(in_file, patchsize, config)
+    inferred = None
+    # Try full disk
+    if fulldisk:
+        try:
+            logger.info('Attempting full disk inference...')
+            in_fd = np.stack([file_dset.map.data, get_array_radius(file_dset.map)], axis=0)
+            inferred = model.forward(
+                torch.from_numpy(in_fd[None]).to(device).float()).detach().numpy()[
+                    0, ...] * norm
+            logger.info('Success.')
+
+        except Exception:
+            logger.info('Full disk inference failed', exc_info=True)
+
+    else:
+        logger.info('Attempting inference on patches...')
+        dataloader = DataLoader(file_dset, batch_size=8, shuffle=False)
+
+        output_patches = []
+
+        for patch in dataloader:
+            patch.to(device)
+            output = model.forward(patch) * norm
+
+            output_patches.append(output.detach().cpu().numpy())
+
+        inferred = get_image_from_patches(output_patches)
+        logger.info(f'Success.')
+
+    if inferred:
+        inferred_map = file_dset.create_new_map(inferred, model.name)
+        inferred_map.save(out_file.as_posix(), overwrite=True)
+
+        if args.plot:
+            plot_magnetogram(inferred_map, out_file.with_suffix('.png'))
+
+        del inferred_map
+

-    list_of_files = []
-    for (dirpath, dirnames, filenames) in os.walk(args.data_path):
-        list_of_files += [os.path.join(dirpath, file) for file in filenames if file.endswith('.fits') or file.endswith('.fits.gz')]
-
-    os.makedirs(args.destination, exist_ok=True)
-
-    for file in list_of_files:
-        logger.info(f'Processing {file}')
-
-        output_file = args.destination + '/' + '.'.join(file.split('/')[-1].split('.gz')[0].split('.')[0:-1])
-        if os.path.exists(output_file + '_HR.fits') and not args.overwrite:
-            logger.info(f'{file} already exists')
-
-        else:
-            file_dset = FitsFileDataset(file, 32, norm, instrument, rescale, upscale_factor)
-
-            # Try full disk
-            success_sw = False
-            if not args.use_patches:
-                success_sw = True
-                try:
-                    logger.info(f'Attempting full disk inference...')
-                    in_fd = np.stack([file_dset.map.data, get_array_radius(file_dset.map)], axis=0)
-                    inferred = model.forward(torch.from_numpy(in_fd[None]).to(device).float()).detach().numpy()[0,...]*norm
-                    logger.info(f'Success.')
-
-                except Exception as e:
-                    logger.info(f'Failure. {e}')
-                    success_sw = False
-
-            if not success_sw or args.use_patches:
-                logger.info(f'Attempting inference on patches...')
-                dataloader = DataLoader(file_dset, batch_size=8, shuffle=False)
-
-                output_patches = []
-
-                for input in dataloader:
-                    input = input.to(device)
-                    output = model.forward(input) * norm
-
-                    output_patches.append(output.detach().cpu().numpy())
-
-                inferred = get_image_from_array(output_patches)
-                logger.info(f'Success.')
-
-            inferred_map = file_dset.create_new_map(inferred, upscale_factor, args.add_noise, model_name, config_data, padding)
-            inferred_map.save(output_file + '_HR.fits', overwrite=True)
-
-            if args.plot:
-                plot_magnetogram(inferred_map, output_file + '_HR.png')
-
-            del inferred_map
+
+if __name__ == '__main__':
+    logging.root.setLevel('INFO')
+    logger = get_logger(__name__)
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--instrument', required=True, choices=['mdi', 'gong'])
+    parser.add_argument('--source_dir', required=True, type=str)
+    parser.add_argument('--destination_dir', required=True, type=str)
+    parser.add_argument('--add_noise', type=float)
+    parser.add_argument('--plot', action='store_true')
+    parser.add_argument('--overwrite', action='store_true')
+    parser.add_argument('--fulldisk', action='store_true')
+    parser.add_argument('--zero_outside', action='store_true')
+    parser.add_argument('--no_rescale', action='store_true')
+
+    args = parser.parse_args()
+
+    source_dir = Path(args.source_dir)
+    destination_dir = Path(args.destination_dir)
+    overwrite = args.overwrite
+
+    checkpoint_dir, config_data = get_config(**vars(args))
+
+    model = get_model(checkpoint_dir, config_data)
+
+    source_files = chain(source_dir.rglob('*.fits'), source_dir.rglob('*.fits.gz'))
+
+    destination_dir.mkdir(exist_ok=True, parents=True)
+
+    for file in source_files:
+        logger.info(f'Processing {file}')
+        out_path = destination_dir / (file.stem + '_HR.fits')
+
+        if out_path.exists() and not overwrite:
+            logger.info(f'{file} already exists')
+        else:
+            convert(file, out_path, config_data)
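For reference, based on the argument parser above, the refactored script would be invoked along these lines (paths here are illustrative):

```
python convert2HMI.py --instrument mdi --source_dir /data/mdi --destination_dir /data/mdi_hr --fulldisk --plot
```

Omitting --fulldisk falls back to patch-based inference, and --overwrite re-processes files whose _HR.fits output already exists.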
2 changes: 0 additions & 2 deletions requirements.txt
@@ -14,7 +14,5 @@ torch==1.1.0
 torchvision==0.3.0
 tensorboard==1.14.0
 tensorboardX==1.8
-gitpython==3.1.0
-tensorflow==1.15.2
Comment (Collaborator, author):
We should also remove any other unused requirements.
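One quick way to spot pins that are no longer imported anywhere, assuming the distribution names map to these import names (gitpython is imported as git):

```
grep -rnE --include='*.py' '^(import|from) (git|tensorflow)\b' .
```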

Comment (Contributor):
Should there be a Python version requirement in here?
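requirements.txt itself cannot pin an interpreter version; one common place for that is a python_requires constraint in setup.py. A minimal sketch, assuming the project adds a setup.py (the distribution name is made up):

```python
from setuptools import setup, find_packages

setup(
    name='mdi2hmi',                 # hypothetical distribution name
    packages=find_packages(),
    python_requires='>=3.6',        # f-strings and pathlib usage in this PR need 3.6+
)
```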

scikit-learn==0.21.3
gcsfs==0.2.3