From fd48e97ab9c43a3bf3374b4e752bfe8a3d6ce69f Mon Sep 17 00:00:00 2001
From: Shane Maloney
Date: Sat, 25 Apr 2020 18:20:42 +0100
Subject: [PATCH 1/2] Initial refactoring

* Now works on some data obtained from the VSO
* Removed some repeated code
* Formatting and imports

---
 convert2HMI.py       |  88 +++++----------
 source/data_utils.py | 256 +++++++++++++------------------------
 source/dataset.py    |  67 +++++++----
 source/load.py       |   4 +-
 source/utils.py      |   8 +-
 5 files changed, 149 insertions(+), 274 deletions(-)

diff --git a/convert2HMI.py b/convert2HMI.py
index 153a61c..74d4915 100644
--- a/convert2HMI.py
+++ b/convert2HMI.py
@@ -1,52 +1,20 @@
-import os
-import sys
-
 import argparse
-import yaml
-import logging
+import os
 
 import numpy as np
-from astropy.io import fits
-import astropy.units as u
-import sunpy.map
-
 import torch
+import yaml
 from torch.utils.data.dataloader import DataLoader
 
-from source.models.model_manager import BaseScaler
+from source.data_utils import get_array_radius, get_image_from_patches, plot_magnetogram
 from source.dataset import FitsFileDataset
-from source.data_utils import get_array_radius, get_image_from_array, plot_magnetogram
-
-
-def get_logger(name):
-    """
-    Return a logger for current module
-
-    Returns
-    -------
-
-    logger : logger instance
-
-    """
-    logger = logging.getLogger(name)
-    logger.setLevel(logging.DEBUG)
-    formatter = logging.Formatter(fmt="%(asctime)s %(levelname)s %(name)s: %(message)s",
-                                  datefmt="%Y-%m-%d - %H:%M:%S")
-    console = logging.StreamHandler(sys.stdout)
-    console.setLevel(logging.DEBUG)
-    console.setFormatter(formatter)
-
-    logfile = logging.FileHandler('run.log', 'w')
-    logfile.setLevel(logging.DEBUG)
-    logfile.setFormatter(formatter)
-
-    logger.addHandler(console)
-    logger.addHandler(logfile)
-
-    return logger
+from source.models.model_manager import BaseScaler
+from source.utils import get_logger
 
 if __name__ == '__main__':
     logger = get_logger(__name__)
 
     parser = argparse.ArgumentParser()
-    parser.add_argument('--instrument', required=True)
+    parser.add_argument('--instrument', required=True, choices=['mdi', 'gong'])
     parser.add_argument('--data_path')
     parser.add_argument('--destination')
     parser.add_argument('--add_noise')
@@ -56,16 +24,13 @@ def get_logger(name):
     parser.add_argument('--zero_outside', action='store_true')
     parser.add_argument('--no_rescale', action='store_true')
 
-
     args = parser.parse_args()
     instrument = args.instrument.lower()
 
-    if instrument == 'mdi':
+    if args.instrument == 'mdi':
         run = 'checkpoints/mdi/20200312194454_HighResNet_RPRCDO_SSIMGradHistLoss_mdi_19'
-    elif instrument == 'gong':
+    elif args.instrument == 'gong':
         run = 'checkpoints/gong/20200321142757_HighResNet_RPRCDO_SSIMGradHistLoss_gong_1'
-    else:
-        raise RuntimeError(f'mdi and gong are the only valid instruments.')
 
     with open(run + '.yml', 'r') as stream:
         config_data = yaml.load(stream, Loader=yaml.SafeLoader)
@@ -95,26 +60,20 @@ def get_logger(name):
     model = model.to(device)
 
     checkpoint = torch.load(run, map_location='cpu')
-    try:
-
-        try:
-            model.load_state_dict(checkpoint['model_state_dict'])
-        except:
-            state_dict = {}
-            for key, value in checkpoint['model_state_dict'].items():
-                state_dict['.'.join(key.split('.')[1:])] = value
-            model.load_state_dict(state_dict)
-    except:
-        state_dict = {}
-        for key, value in checkpoint['model_state_dict'].items():
-            state_dict['.'.join(np.append(['module'], key.split('.')[0:]))] = value
-            model.load_state_dict(state_dict)
+    state_dict = {}
+    for key, value in checkpoint['model_state_dict'].items():
+        if 
key.startswith('module'): + state_dict['.'.join(key.split('.')[1:])] = value + else: + state_dict[key] = value + model.load_state_dict(state_dict) list_of_files = [] for (dirpath, dirnames, filenames) in os.walk(args.data_path): - list_of_files += [os.path.join(dirpath, file) for file in filenames if file.endswith('.fits') or file.endswith('.fits.gz')] + list_of_files += [os.path.join(dirpath, file) for file in filenames if + file.endswith('.fits') or file.endswith('.fits.gz')] os.makedirs(args.destination, exist_ok=True) @@ -122,7 +81,8 @@ def get_logger(name): logger.info(f'Processing {file}') - output_file = args.destination + '/' + '.'.join(file.split('/')[-1].split('.gz')[0].split('.')[0:-1]) + output_file = args.destination + '/' + '.'.join( + file.split('/')[-1].split('.gz')[0].split('.')[0:-1]) if os.path.exists(output_file + '_HR.fits') and not args.overwrite: logger.info(f'{file} already exists') @@ -137,7 +97,9 @@ def get_logger(name): try: logger.info(f'Attempting full disk inference...') in_fd = np.stack([file_dset.map.data, get_array_radius(file_dset.map)], axis=0) - inferred = model.forward(torch.from_numpy(in_fd[None]).to(device).float()).detach().numpy()[0,...]*norm + inferred = model.forward( + torch.from_numpy(in_fd[None]).to(device).float()).detach().numpy()[ + 0, ...] * norm logger.info(f'Success.') except Exception as e: @@ -151,16 +113,16 @@ def get_logger(name): output_patches = [] for input in dataloader: - input = input.to(device) output = model.forward(input) * norm output_patches.append(output.detach().cpu().numpy()) - inferred = get_image_from_array(output_patches) + inferred = get_image_from_patches(output_patches) logger.info(f'Success.') - inferred_map = file_dset.create_new_map(inferred, upscale_factor, args.add_noise, model_name, config_data, padding) + inferred_map = file_dset.create_new_map(inferred, upscale_factor, args.add_noise, + model_name, config_data, padding) inferred_map.save(output_file + '_HR.fits', overwrite=True) if args.plot: diff --git a/source/data_utils.py b/source/data_utils.py index 97b7adb..a3bbbed 100644 --- a/source/data_utils.py +++ b/source/data_utils.py @@ -1,15 +1,11 @@ import numpy as np import math -import datetime from astropy import units as u from sklearn.feature_extraction import image from astropy.coordinates import SkyCoord from sunpy.map import Map from astropy.io import fits -from astropy.time import Time -from astropy.coordinates import solar_system_ephemeris, EarthLocation, get_body - import matplotlib.pyplot as plt @@ -45,172 +41,54 @@ def map_prep(file, instrument, *keyward_args): if len(hdul) == 2: sun_map = Map(hdul[1].data, hdul[1].header) - - elif len(hdul) == 1: - if instrument == 'mdi': - - header = hdul[0].header - if header['SOLAR_P0']: - header['RSUN_OBS'] = header['OBS_R0'] - header['RSUN_REF'] = 696000000 - header['CROTA2'] = -header['SOLAR_P0'] - header['CRVAL1'] = 0.000000 - header['CRVAL2'] = 0.000000 - header['CUNIT1'] = 'arcsec' - header['CUNIT2'] = 'arcsec' - header['DSUN_OBS'] = header['OBS_DIST'] - header['DSUN_REF'] = 1 - - try: - header.pop('SOLAR_P0') - header.pop('OBS_DIST') - header.pop('OBS_R0') - except: - pass - - data = hdul[0].data - - if instrument == 'gong': - - header = hdul[0].header - if len(header['DATE-OBS'])<22: - header['RSUN_OBS'] = header['RADIUS'] * 180 / np.pi * 60 * 60 - header['RSUN_REF'] = 696000000 - header['CROTA2'] = 0 - header['CUNIT1'] = 'arcsec' - header['CUNIT2'] = 'arcsec' - header['DSUN_OBS'] = header['DISTANCE'] * 149597870691 - header['DSUN_REF'] = 
149597870691 - header['cdelt1'] = 2.5534 - header['cdelt2'] = 2.5534 - - header['CTYPE1'] = 'HPLN-TAN' - header['CTYPE2'] = 'HPLT-TAN' - - - date = header['DATE-OBS'] - header['DATE-OBS'] = date[0:4] + '-' + date[5:7] + '-' + date[8:10] + 'T' + header['TIME-OBS'][0:11] - - data = hdul[0].data - - if instrument == 'spmg': - header = hdul[0].header - header['cunit1'] = 'arcsec' - header['cunit2'] = 'arcsec' - header['CDELT1'] = header['CDELT1A'] - header['CDELT2'] = header['CDELT2A'] - header['CRVAL1'] = 0 - header['CRVAL2'] = 0 - header['RSUN_OBS'] = header['EPH_R0 '] - header['CROTA2'] = 0 - header['CRPIX1'] = header['CRPIX1A'] - header['CRPIX2'] = header['CRPIX2A'] - header['PC2_1'] = 0 - header['PC1_2'] = 0 + data = sun_map.data.astype('>f4') + header = sun_map.meta + else: + data = hdul[0].data + header = hdul[0].header + + if instrument == 'mdi': + if 'SOLAR_P0' in header: + header['RSUN_OBS'] = header['OBS_R0'] header['RSUN_REF'] = 696000000 - - # Adding distance to header - t = Time(header['DATE-OBS']) - loc = EarthLocation.of_site('kpno') - with solar_system_ephemeris.set('builtin'): - sun = get_body('sun', t, loc) - header['DSUN_OBS'] = sun.distance.to('m').value - header['DSUN_REF'] = 149597870691 - - # selecting right layer for data - data = hdul[0].data[5, :, :] - - if instrument == 'kp512': - header = hdul[0].header - header['cunit1'] = 'arcsec' - header['cunit2'] = 'arcsec' - header['CDELT1'] = header['CDELT1A'] - header['CDELT2'] = header['CDELT2A'] - header['CRVAL1'] = 0 - header['CRVAL2'] = 0 - header['RSUN_OBS'] = header['EPH_R0 '] - header['CROTA2'] = 0 - header['CRPIX1'] = header['CRPIX1A'] - header['CRPIX2'] = header['CRPIX2A'] - header['PC2_1'] = 0 - header['PC1_2'] = 0 - header['RSUN_REF'] = 696000000 - - # Adding distance to header - t = Time(header['DATE-OBS']) - loc = EarthLocation.of_site('kpno') - with solar_system_ephemeris.set('builtin'): - sun = get_body('sun', t, loc) - header['DSUN_OBS'] = sun.distance.to('m').value - header['DSUN_REF'] = 149597870691 - - # selecting right layer for data - data = hdul[0].data[2, :, :] - - if instrument == 'mwo': - - file_name = file.name - - # Deconstruct Name to assess date - tmpPos = file_name.rfind('_') - - year = int(file_name[tmpPos - 6:tmpPos - 4]) - - # Adding century - if year < 1960: - year += 2000 - else: - year += 1900 - - month = int(file_name[tmpPos - 4:tmpPos - 2]) - day = int(file_name[tmpPos - 2:tmpPos]) - hr = int(file_name[tmpPos + 1:tmpPos + 3]) - 1 - mn = int(file_name[tmpPos + 3:tmpPos + 5]) - sc = 0 - - # Fix Times - if mn > 59: - mn = mn - 60 - hr = hr + 1 - - # Assemble date - if hr > 23: - tmpDate = datetime.datetime(year, month, day, hr - 24, mn, - sc) + datetime.timedelta(days=1) - else: - tmpDate = datetime.datetime(year, month, day, hr, mn, sc) - - header = hdul[0].header + header['CROTA2'] = -header['SOLAR_P0'] + header['CRVAL1'] = 0.000000 + header['CRVAL2'] = 0.000000 header['CUNIT1'] = 'arcsec' header['CUNIT2'] = 'arcsec' - header['CDELT1'] = header['DXB_IMG'] - header['CDELT2'] = header['DYB_IMG'] - header['CRVAL1'] = 0.0 - header['CRVAL2'] = 0.0 - header['RSUN_OBS'] = (header['R0']) * header['DXB_IMG'] - header['CROTA2'] = 0.0 - header['CRPIX1'] = header['X0'] - 0.5 - header['CRPIX2'] = header['Y0'] - 0.5 - header['T_OBS'] = tmpDate.strftime('%Y-%m-%dT%H-%M:00.0') - header['DATE-OBS'] = tmpDate.strftime('%Y-%m-%dT%H:%M:00.0') - header['DATE_OBS'] = tmpDate.strftime('%Y-%m-%dT%H:%M:00.0') + header['DSUN_OBS'] = header['OBS_DIST'] + header['DSUN_REF'] = 1 + + if 'DSUN_REF' not in header: + 
header['DSUN_REF'] = u.au.to('m') + + try: + header.pop('SOLAR_P0') + header.pop('OBS_DIST') + header.pop('OBS_R0') + except KeyError: + pass + + elif instrument == 'gong': + if len(header['DATE-OBS'])<22: + header['RSUN_OBS'] = header['RADIUS'] * 180 / np.pi * 60 * 60 header['RSUN_REF'] = 696000000 + header['CROTA2'] = 0 + header['CUNIT1'] = 'arcsec' + header['CUNIT2'] = 'arcsec' + header['DSUN_OBS'] = header['DISTANCE'] * 149597870691 + header['DSUN_REF'] = 149597870691 + header['cdelt1'] = 2.5534 + header['cdelt2'] = 2.5534 + header['CTYPE1'] = 'HPLN-TAN' header['CTYPE2'] = 'HPLT-TAN' - header['RSUN_REF'] = 696000000 - # Adding distance to header - t = Time(header['DATE-OBS'], format='isot') - loc = EarthLocation.of_site('mwo') - with solar_system_ephemeris.set('builtin'): - sun = get_body('sun', t, loc) - header['DSUN_OBS'] = sun.distance.to('m').value - header['DSUN_REF'] = 149597870691 + date = header['DATE-OBS'] + header['DATE-OBS'] = date[0:4] + '-' + date[5:7] + '-' + date[8:10] + 'T' +\ + header['TIME-OBS'][0:11] - # selecting right layer for data - data = hdul[0].data - - sun_map = Map(data, header) + sun_map = Map(data, header) return sun_map @@ -291,44 +169,52 @@ def scale_rotate(amap, target_scale=0.504273, target_factor=0): crop_map.meta['xscale'] = target_factor * target_scale crop_map.meta['yscale'] = target_factor * target_scale - return crop_map -def get_patch(amap, size): +def get_patches(amap, size): """ - create patches of dimension size * size with a defined stride. + create patches of dimension size * size with a stride of size. + Since stride is equals to size, there is no overlap Parameters ---------- - amap: sunpy map + amap : sunpy.map.Map + Input map to create tiles from - size: integer - size of each patch + size : int + Size of each patch Returns ------- - numpy array [num_patches, num_channel, size, size] - channels are magnetic field, radial distance relative to radius + numpy.array (num_patches, num_channel, size, size) + Channels are magnetic field, radial distance relative to radius """ array_radius = get_array_radius(amap) patches = image.extract_patches(amap.data, (size, size), extraction_step=size) - patches = patches.reshape([-1] + list((size, size))) + patches = patches.reshape(-1, size, size) patches_r = image.extract_patches(array_radius, (size, size), extraction_step=size) - patches_r = patches_r.reshape([-1] + list((size, size))) + patches_r = patches_r.reshape(-1, size, size) return np.stack([patches, patches_r], axis=1) def get_array_radius(amap): """ - Compute an array with the radial coordinate for each pixel - :param amap: - :return: (W, H) array + Compute an array with the radial coordinate for each pixel. 
+ + Parameters + ---------- + amap : sunpy.map.Map + Input map + Returns + ------- + numpy.ndarray + Array full of radius values """ x, y = np.meshgrid(*[np.arange(v.value) for v in amap.dimensions]) * u.pixel hpc_coords = amap.pixel_to_world(x, y) @@ -337,11 +223,18 @@ def get_array_radius(amap): return array_radius -def get_image_from_array(list_patches): +def get_image_from_patches(list_patches): """ Reconstruct from a list of patches the full disk image - :param list_patches: - :return: + + Parameters + ---------- + list_patches : list of patches produced from `get_patches` + + Returns + ------- + numpy.ndarray + Reconstructed image """ out = np.array(list_patches) out_r = out.reshape(out.shape[0] * out.shape[1], out.shape[2], out.shape[3]) @@ -357,8 +250,7 @@ def get_image_from_array(list_patches): def plot_magnetogram(amap, file, scale=1, vmin=-2000, vmax=2000, cmap=plt.cm.get_cmap('hmimag')): """ Plot magnetogram - :param amap: - :return: (W, H) array + """ # Size definitions @@ -394,8 +286,8 @@ def plot_magnetogram(amap, file, scale=1, vmin=-2000, vmax=2000, cmap=plt.cm.get ax1 = fig.add_axes([ppadh + ppxx, ppadv + ppxy, ppxx, ppxy]) ax1.imshow(amap.data, vmin=vmin, vmax=vmax, cmap=cmap, origin='lower') ax1.set_axis_off() - ax1.text(0.99, 0.99, 'ML Output', horizontalalignment='right', verticalalignment='top', color='k', - transform=ax1.transAxes) + ax1.text(0.99, 0.99, 'ML Output', horizontalalignment='right', verticalalignment='top', + color='k', transform=ax1.transAxes) fig.savefig(file, bbox_inches='tight', dpi=dpi, pad_inches=0) diff --git a/source/dataset.py b/source/dataset.py index a4c786d..997aa89 100644 --- a/source/dataset.py +++ b/source/dataset.py @@ -1,35 +1,44 @@ -import numpy as np - +from datetime import datetime -from torch.utils.data import Dataset -from torch import from_numpy +import numpy as np from sunpy.map import Map -from datetime import datetime +from torch import from_numpy +from torch.utils.data import Dataset -from source.data_utils import get_patch, get_array_radius, map_prep, scale_rotate +from source.data_utils import get_patches, get_array_radius, map_prep, scale_rotate class FitsFileDataset(Dataset): """ Construct a dataset of patches from a fits file """ + def __init__(self, file, size, norm, instrument, rescale, upscale_factor): map = map_prep(file, instrument) - map.data[:] = map.data[:]/norm + map.data[:] = map.data[:] / norm # Detecting need for rescale - if rescale and np.abs(1 - (map.meta['cdelt1']/0.504273)/upscale_factor) > 0.01: + if rescale and np.abs(1 - (map.meta['cdelt1'] / 0.504273) / upscale_factor) > 0.01: map = scale_rotate(map, target_factor=upscale_factor) - self.data = get_patch(map, size) + self.data = get_patches(map, size) self.map = map + def __getitem__(self, idx): """ - Create torch tensor from patch with id idx - :param idx: - :return: tensor (W, H, 2) + Get a patch + + Parameters + ---------- + idx : int + Index + + Returns + ------- + torch.tensor + patch """ patch = self.data[idx, ...] 
patch[patch != patch] = 0 @@ -39,17 +48,33 @@ def __getitem__(self, idx): def __len__(self): return self.data.shape[0] - def create_new_map(self, new_data, scale_factor, add_noise, model_name, config_data, padding): + def create_new_map(self, new_data, scale_factor, add_noise, model_name, config_data, padding): """ Adjust header to match upscaling factor and add new keywords - :return: + + Parameters + ---------- + new_data : + Superresolved map data + scale_factor : int + Upscaling factor + add_noise : bool + True add noise to superreolved image + model_name : str + Name of the model use + config_data : dict + Configuration parameters + padding : + """ new_meta = self.map.meta.copy() # Changing scale and center - new_meta['crpix1'] = new_meta['crpix1'] - self.map.data.shape[0] / 2 + self.map.data.shape[0] * scale_factor / 2 - new_meta['crpix2'] = new_meta['crpix2'] - self.map.data.shape[1] / 2 + self.map.data.shape[1] * scale_factor / 2 + new_meta['crpix1'] = (new_meta['crpix1'] - self.map.data.shape[0] / 2 + + self.map.data.shape[0] * scale_factor / 2) + new_meta['crpix2'] = (new_meta['crpix2'] - self.map.data.shape[1] / 2 + + self.map.data.shape[1] * scale_factor / 2) new_meta['cdelt1'] = new_meta['cdelt1'] / scale_factor new_meta['cdelt2'] = new_meta['cdelt2'] / scale_factor @@ -68,12 +93,11 @@ def create_new_map(self, new_data, scale_factor, add_noise, model_name, config_ # Changing data info new_meta['datamin'] = np.nanmin(self.map.data) new_meta['datamax'] = np.nanmax(self.map.data) - new_meta['data_rms'] = np.sqrt(np.nanmean(self.map.data**2)) + new_meta['data_rms'] = np.sqrt(np.nanmean(self.map.data ** 2)) new_meta['datamean'] = np.nanmean(self.map.data) new_meta['datamedn'] = np.nanmedian(self.map.data) new_meta['dataskew'] = np.nanmedian(self.map.data) - # Add keywords related to conversion try: new_meta['instrume'] = new_meta['instrume'] + '-2HMI_HR' @@ -83,7 +107,8 @@ def create_new_map(self, new_data, scale_factor, add_noise, model_name, config_ new_meta['hrkey1'] = '---------------- HR ML Keywords Section ----------------' new_meta['date-ml'] = str(datetime.utcnow()) new_meta['nn-model'] = model_name - new_meta['loss'] = ', '.join('{!s}={!r}'.format(key, val) for (key, val) in config_data['loss'].items()) + new_meta['loss'] = ', '.join( + '{!s}={!r}'.format(key, val) for (key, val) in config_data['loss'].items()) new_meta['conv_doi'] = 'https://doi.org/10.5281/zenodo.3750372' new_meta['hrkey2'] = '---------------- HR ML Keywords Section ----------------' @@ -97,7 +122,3 @@ def create_new_map(self, new_data, scale_factor, add_noise, model_name, config_ new_map.data[array_radius >= 1] = padding return new_map - - - - diff --git a/source/load.py b/source/load.py index 2417357..cdb3794 100644 --- a/source/load.py +++ b/source/load.py @@ -7,9 +7,11 @@ logger = get_logger(__name__) + def load_from_google_cloud(run_name, epoch, model): """ - Construct a torch model from pe-trained model run_name stored on goolge cloud + Construct a torch model from pe-trained model run_name stored on goolge cloud. + :param run_name: string :param epoch: int :param model: torch model diff --git a/source/utils.py b/source/utils.py index ec65e46..38a0bbc 100644 --- a/source/utils.py +++ b/source/utils.py @@ -3,10 +3,10 @@ import warnings - def get_logger(name): """ - Return a logger for current module + Return a logger for current module. 
+
 
     Returns
     -------
 
@@ -30,6 +30,7 @@ def get_logger(name):
 
     return logger
 
+
 def disable_warnings():
     """
     Disable printing of warnings
@@ -39,6 +40,3 @@ def disable_warnings():
     None
     """
     warnings.simplefilter("ignore")
-
-
-
\ No newline at end of file

From 5e523eed031c8d0054e021aa2a8a1fe21944bc7c Mon Sep 17 00:00:00 2001
From: Shane Maloney
Date: Sun, 26 Apr 2020 18:31:00 +0100
Subject: [PATCH 2/2] More refactoring

* Might work on Windows now
* Separate CLI from functions
* Docstrings
* Pass config as dict
* Address lint and format warnings and errors

---
 convert2HMI.py       | 245 ++++++++++++++++++++++++++++---------------
 requirements.txt     |   2 -
 source/data_utils.py |  78 ++++++++------
 source/dataset.py    |  60 +++++------
 source/load.py       |  45 --------
 source/utils.py      |   7 +-
 6 files changed, 241 insertions(+), 196 deletions(-)
 delete mode 100644 source/load.py

diff --git a/convert2HMI.py b/convert2HMI.py
index 74d4915..33342d7 100644
--- a/convert2HMI.py
+++ b/convert2HMI.py
@@ -1,131 +1,208 @@
 import argparse
-import os
+import logging
+from itertools import chain
+from pathlib import Path
 
 import numpy as np
 import torch
 import yaml
 from torch.utils.data.dataloader import DataLoader
 
-from source.data_utils import get_array_radius, get_image_from_patches, plot_magnetogram
-from source.dataset import FitsFileDataset
 from source.models.model_manager import BaseScaler
 from source.utils import get_logger
 
-if __name__ == '__main__':
-    logger = get_logger(__name__)
 
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--instrument', required=True, choices=['mdi', 'gong'])
-    parser.add_argument('--data_path')
-    parser.add_argument('--destination')
-    parser.add_argument('--add_noise')
-    parser.add_argument('--plot', action='store_true')
-    parser.add_argument('--overwrite', action='store_true')
-    parser.add_argument('--use_patches', action='store_true')
-    parser.add_argument('--zero_outside', action='store_true')
-    parser.add_argument('--no_rescale', action='store_true')
-
-    args = parser.parse_args()
-    instrument = args.instrument.lower()
-
-    if args.instrument == 'mdi':
-        run = 'checkpoints/mdi/20200312194454_HighResNet_RPRCDO_SSIMGradHistLoss_mdi_19'
-    elif args.instrument == 'gong':
-        run = 'checkpoints/gong/20200321142757_HighResNet_RPRCDO_SSIMGradHistLoss_gong_1'
-
-    with open(run + '.yml', 'r') as stream:
+
+def get_config(instrument, fulldisk, zero_outside, add_noise, no_rescale, **kwargs):
+    """
+    Load the run configuration for the given instrument and fold in the CLI options.
+ + Parameters + ---------- + instrument : str + Instrument name + fulldisk : bool + Fulldisk based inference, default is patch based + zero_outside : + Set region outside solar radius to zeros instead of default nan values + add_noise : float (optional) + Scale or standard deviation of noise to add + no_rescale : bool (optional) + Disable rescaling + + Returns + ------- + tuple + Run and config dict + """ + if instrument == 'mdi': + run_dir = Path('checkpoints/mdi/20200312194454_HighResNet_RPRCDO_SSIMGradHistLoss_mdi_19') + elif instrument == 'gong': + run_dir = Path('checkpoints/gong/20200321142757_HighResNet_RPRCDO_SSIMGradHistLoss_gong_1') + + with run_dir.with_suffix('.yml').open() as stream: config_data = yaml.load(stream, Loader=yaml.SafeLoader) + config_data['cli'] = {'fulldisk': fulldisk, + 'zero_outside': zero_outside, + 'add_noise': add_noise} + + config_data['instrument'] = instrument data_config = config_data['data'] - norm = 3500 - if 'normalisation' in data_config.keys(): - norm = data_config['normalisation'] + + if 'normalisation' not in data_config.keys(): + data_config['normalisation'] = 3500.0 padding = np.nan - if args.zero_outside: + if zero_outside: padding = 0 + data_config['padding'] = padding rescale = True - if args.no_rescale: + if no_rescale: rescale = False + data_config['rescale'] = rescale net_config = config_data['net'] - model_name = net_config['name'] - upscale_factor = 4 - if 'upscale_factor' in net_config.keys(): - upscale_factor = net_config['upscale_factor'] - - model = BaseScaler.from_dict(config_data) - - device = torch.device("cpu") - model = model.to(device) - - checkpoint = torch.load(run, map_location='cpu') + if 'upscale_factor' not in net_config.keys(): + net_config['upscale_factor'] = 4 + + return run_dir, config_data + + +def get_model(run, config): + """ + Get a model based on the run and config data + + Parameters + ---------- + run : pathlib.Path + Path to run directory + config : dict + Config data + + Returns + ------- + source.model.model_manger.TemplateModel + The model + """ + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + config['device'] = device + amodel = BaseScaler.from_dict(config) + amodel = amodel.to(device) + checkpoint = torch.load(run.as_posix(), map_location='cpu') state_dict = {} for key, value in checkpoint['model_state_dict'].items(): if key.startswith('module'): state_dict['.'.join(key.split('.')[1:])] = value else: state_dict[key] = value + amodel.load_state_dict(state_dict) - model.load_state_dict(state_dict) + return amodel - list_of_files = [] - for (dirpath, dirnames, filenames) in os.walk(args.data_path): - list_of_files += [os.path.join(dirpath, file) for file in filenames if - file.endswith('.fits') or file.endswith('.fits.gz')] - os.makedirs(args.destination, exist_ok=True) +def convert(in_file, out_file, config, patchsize=32): + """ + Convert a file to HMI - for file in list_of_files: + Parameters + ---------- + in_file : pathlib.Path + Input fits file + out_file : pathlib.Path + Output fits file + config : dict + Configuration dictionary + patchsize : int + Size of the patches created - logger.info(f'Processing {file}') + Returns + ------- - output_file = args.destination + '/' + '.'.join( - file.split('/')[-1].split('.gz')[0].split('.')[0:-1]) - if os.path.exists(output_file + '_HR.fits') and not args.overwrite: - logger.info(f'{file} already exists') + """ + # Really slow imports so only import if we reach the point where it is needed + from source.dataset import FitsFileDataset + 
from source.data_utils import get_array_radius, get_image_from_patches, plot_magnetogram - else: + norm = config['data']['normalisation'] + device = config['device'] + fulldisk = config['cli']['fulldisk'] - file_dset = FitsFileDataset(file, 32, norm, instrument, rescale, upscale_factor) + file_dset = FitsFileDataset(in_file, patchsize, config) + inferred = None + # Try full disk + if fulldisk: + try: + logger.info('Attempting full disk inference...') + in_fd = np.stack([file_dset.map.data, get_array_radius(file_dset.map)], axis=0) + inferred = model.forward( + torch.from_numpy(in_fd[None]).to(device).float()).detach().numpy()[ + 0, ...] * norm + logger.info('Success.') - # Try full disk - success_sw = False - if not args.use_patches: - success_sw = True - try: - logger.info(f'Attempting full disk inference...') - in_fd = np.stack([file_dset.map.data, get_array_radius(file_dset.map)], axis=0) - inferred = model.forward( - torch.from_numpy(in_fd[None]).to(device).float()).detach().numpy()[ - 0, ...] * norm - logger.info(f'Success.') + except Exception: + logger.info('Full disk inference failed', exc_info=True) - except Exception as e: - logger.info(f'Failure. {e}') - success_sw = False + else: + logger.info('Attempting inference on patches...') + dataloader = DataLoader(file_dset, batch_size=8, shuffle=False) - if not success_sw or args.use_patches: - logger.info(f'Attempting inference on patches...') - dataloader = DataLoader(file_dset, batch_size=8, shuffle=False) + output_patches = [] - output_patches = [] + for patch in dataloader: + patch.to(device) + output = model.forward(patch) * norm - for input in dataloader: - input = input.to(device) - output = model.forward(input) * norm + output_patches.append(output.detach().cpu().numpy()) - output_patches.append(output.detach().cpu().numpy()) + inferred = get_image_from_patches(output_patches) + logger.info(f'Success.') - inferred = get_image_from_patches(output_patches) - logger.info(f'Success.') + if inferred: + inferred_map = file_dset.create_new_map(inferred, model.name) + inferred_map.save(out_file.as_posix(), overwrite=True) - inferred_map = file_dset.create_new_map(inferred, upscale_factor, args.add_noise, - model_name, config_data, padding) - inferred_map.save(output_file + '_HR.fits', overwrite=True) + if args.plot: + plot_magnetogram(inferred_map, out_file.with_suffix('.png')) - if args.plot: - plot_magnetogram(inferred_map, output_file + '_HR.png') + del inferred_map - del inferred_map + +if __name__ == '__main__': + logging.root.setLevel('INFO') + logger = get_logger(__name__) + + parser = argparse.ArgumentParser() + parser.add_argument('--instrument', required=True, choices=['mdi', 'gong']) + parser.add_argument('--source_dir', required=True, type=str) + parser.add_argument('--destination_dir', required=True, type=str) + parser.add_argument('--add_noise', type=float) + parser.add_argument('--plot', action='store_true') + parser.add_argument('--overwrite', action='store_true') + parser.add_argument('--fulldisk', action='store_true') + parser.add_argument('--zero_outside', action='store_true') + parser.add_argument('--no_rescale', action='store_true') + + args = parser.parse_args() + + source_dir = Path(args.source_dir) + destination_dir = Path(args.destination_dir) + overwrite = args.overwrite + + checkpoint_dir, config_data = get_config(**vars(args)) + + model = get_model(checkpoint_dir, config_data) + + source_files = chain(source_dir.rglob('*.fits'), source_dir.rglob('*.fits.gz')) + + destination_dir.mkdir(exist_ok=True, 
parents=True) + + for file in source_files: + logger.info(f'Processing {file}') + out_path = destination_dir / (file.stem + '_HR.fits') + + if out_path.exists() and not overwrite: + logger.info(f'{file} already exists') + else: + convert(file, out_path, config_data) diff --git a/requirements.txt b/requirements.txt index 5a3baad..d3e54b4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,5 @@ torch==1.1.0 torchvision==0.3.0 tensorboard==1.14.0 tensorboardX==1.8 -gitpython==3.1.0 -tensorflow==1.15.2 scikit-learn==0.21.3 gcsfs==0.2.3 \ No newline at end of file diff --git a/source/data_utils.py b/source/data_utils.py index a3bbbed..80c9ef8 100644 --- a/source/data_utils.py +++ b/source/data_utils.py @@ -1,13 +1,12 @@ import numpy as np -import math from astropy import units as u -from sklearn.feature_extraction import image from astropy.coordinates import SkyCoord -from sunpy.map import Map from astropy.io import fits - -import matplotlib.pyplot as plt +from matplotlib import pyplot as plt +from sklearn.feature_extraction import image +from sunpy.cm import cm as scm +from sunpy.map import Map from source.utils import disable_warnings, get_logger @@ -15,22 +14,21 @@ logger = get_logger(__name__) -def map_prep(file, instrument, *keyward_args): +def map_prep(file, instrument): """ - Return a processed hmi magnetogram and path + Return a processed magnetogram Parameters ---------- - file : file desctiptor - instrument: string - - keyward_args : + file : str + File to prep + instrument : str + Instrument name Returns - - tuple : preped map and filepath ------- - + sunpy.map.Map + Preped map and filepath """ # Open fits file as HUDL and fix header @@ -70,7 +68,7 @@ def map_prep(file, instrument, *keyward_args): pass elif instrument == 'gong': - if len(header['DATE-OBS'])<22: + if len(header['DATE-OBS']) < 22: header['RSUN_OBS'] = header['RADIUS'] * 180 / np.pi * 60 * 60 header['RSUN_REF'] = 696000000 header['CROTA2'] = 0 @@ -86,7 +84,7 @@ def map_prep(file, instrument, *keyward_args): date = header['DATE-OBS'] header['DATE-OBS'] = date[0:4] + '-' + date[5:7] + '-' + date[8:10] + 'T' +\ - header['TIME-OBS'][0:11] + header['TIME-OBS'][0:11] sun_map = Map(data, header) @@ -98,30 +96,37 @@ def scale_rotate(amap, target_scale=0.504273, target_factor=0): Parameters ---------- - amap + amap : sunpy.map.Map + Input map + target_factor : float + + target_scale : int + Returns ------- - + sunpy.map.Map + Scaled and rotated map """ scalex = amap.meta['cdelt1'] scaley = amap.meta['cdelt2'] + if scalex != scaley: + raise ValueError('Square pixels expected') # Calculate target factor if not provided if target_factor == 0: target_factor = np.round(scalex / target_scale) ratio_plate = target_factor * target_scale / scalex - # logger.info(np.round(scalex / target_scale) / scalex * target_scale) + logger.debug(np.round(scalex / target_scale) / scalex * target_scale) ratio_dist = amap.meta['dsun_obs'] / amap.meta['dsun_ref'] - # logger.info(ratio_dist) + logger.debug(ratio_dist) # Pad image, if necessary new_shape = int(4096/target_factor) # Reform map to new size if original shape is too small - if new_shape > amap.data.shape[0]: new_fov = np.zeros((new_shape, new_shape)) * np.nan @@ -147,7 +152,7 @@ def scale_rotate(amap, target_scale=0.504273, target_factor=0): x_scale = ((rot_map.scale.axis1 * amap.dimensions.x) / 2) y_scale = ((rot_map.scale.axis2 * amap.dimensions.y) / 2) - # logger.info(f'x-scale {x_scale}, y-scale {y_scale}') + logger.debug(f'x-scale {x_scale}, y-scale {y_scale}') if x_scale != 
y_scale: logger.error(f'x-scale: {x_scale} and y-scale {y_scale} do not match') @@ -239,7 +244,7 @@ def get_image_from_patches(list_patches): out = np.array(list_patches) out_r = out.reshape(out.shape[0] * out.shape[1], out.shape[2], out.shape[3]) - size = int(math.sqrt(out_r.shape[0])) + size = int(np.sqrt(out_r.shape[0])) out_array = np.array_split(out_r, size, axis=0) out_array = np.concatenate(out_array, axis=1) out_array = np.concatenate(out_array, axis=1) @@ -247,10 +252,24 @@ def get_image_from_patches(list_patches): return out_array -def plot_magnetogram(amap, file, scale=1, vmin=-2000, vmax=2000, cmap=plt.cm.get_cmap('hmimag')): +def plot_magnetogram(amap, file, scale=1, vmin=-2000, vmax=2000, cmap=scm.hmimag): """ Plot magnetogram + Parameters + ---------- + amap : sunpy.map.Map + Magnetogram map to plot + file : pathlib.Path + Filename to save plot as + scale : + + vmin : float + Min value for color map + vmax : float + Max value for color map + cmap : matplotlib.colors.Colormap + Colormap """ # Size definitions @@ -275,11 +294,13 @@ def plot_magnetogram(amap, file, scale=1, vmin=-2000, vmax=2000, cmap=plt.cm.get ppxx = pxx / fszh # Horizontal size of each panel in relative units ppxy = pxy / fszv # Vertical size of each panel in relative units ppadv = padv / fszv # Vertical padding in relative units - ppadv2 = padv2 / fszv # Vertical padding in relative units + # Never used so commented out + # ppadv2 = padv2 / fszv # Vertical padding in relative units ppadh = padh / fszh # Horizontal padding the edge of the figure in relative units - ppadh2 = padh2 / fszh # Horizontal padding between panels in relative units + # Never used so commented out + # ppadh2 = padh2 / fszh # Horizontal padding between panels in relative units - ## Start Figure + # Start Figure fig = plt.figure(figsize=(fszh / dpi, fszv / dpi), dpi=dpi) # ## Add Perihelion @@ -290,6 +311,3 @@ def plot_magnetogram(amap, file, scale=1, vmin=-2000, vmax=2000, cmap=plt.cm.get color='k', transform=ax1.transAxes) fig.savefig(file, bbox_inches='tight', dpi=dpi, pad_inches=0) - - - diff --git a/source/dataset.py b/source/dataset.py index 997aa89..e0e2d70 100644 --- a/source/dataset.py +++ b/source/dataset.py @@ -13,18 +13,25 @@ class FitsFileDataset(Dataset): Construct a dataset of patches from a fits file """ - def __init__(self, file, size, norm, instrument, rescale, upscale_factor): + def __init__(self, file, size, config_data): - map = map_prep(file, instrument) - map.data[:] = map.data[:] / norm + self.instrument = config_data['instrument'] + self.norm = config_data['data']['normalisation'] + self.rescale = config_data['data']['rescale'] + self.upscale_factor = config_data['net']['upscale_factor'] + self.padding = config_data['data']['padding'] + self.loss = config_data['loss'] + self.add_noise = config_data['cli']['add_noise'] - # Detecting need for rescale - if rescale and np.abs(1 - (map.meta['cdelt1'] / 0.504273) / upscale_factor) > 0.01: - map = scale_rotate(map, target_factor=upscale_factor) + amap = map_prep(file, self.instrument) + amap.data[:] = amap.data[:] / self.norm - self.data = get_patches(map, size) - self.map = map + # Detecting need for rescale + if self.rescale and np.abs(1 - (amap.meta['cdelt1']/0.504273) / self.upscale_factor) > 0.01: + amap = scale_rotate(amap, target_factor=self.upscale_factor) + self.data = get_patches(amap, size) + self.map = amap def __getitem__(self, idx): """ @@ -48,7 +55,7 @@ def __getitem__(self, idx): def __len__(self): return self.data.shape[0] - def create_new_map(self, 
new_data, scale_factor, add_noise, model_name, config_data, padding): + def create_new_map(self, new_data, model_name): """ Adjust header to match upscaling factor and add new keywords @@ -56,34 +63,25 @@ def create_new_map(self, new_data, scale_factor, add_noise, model_name, config_d ---------- new_data : Superresolved map data - scale_factor : int - Upscaling factor - add_noise : bool - True add noise to superreolved image model_name : str Name of the model use - config_data : dict - Configuration parameters - padding : - """ new_meta = self.map.meta.copy() # Changing scale and center new_meta['crpix1'] = (new_meta['crpix1'] - self.map.data.shape[0] / 2 - + self.map.data.shape[0] * scale_factor / 2) + + self.map.data.shape[0] * self.upscale_factor / 2) new_meta['crpix2'] = (new_meta['crpix2'] - self.map.data.shape[1] / 2 - + self.map.data.shape[1] * scale_factor / 2) - new_meta['cdelt1'] = new_meta['cdelt1'] / scale_factor - new_meta['cdelt2'] = new_meta['cdelt2'] / scale_factor + + self.map.data.shape[1] * self.upscale_factor / 2) + new_meta['cdelt1'] = new_meta['cdelt1'] / self.upscale_factor + new_meta['cdelt2'] = new_meta['cdelt2'] / self.upscale_factor try: - new_meta['im_scale'] = new_meta['im_scale'] / scale_factor - new_meta['fd_scale'] = new_meta['im_scale'] / scale_factor - new_meta['xscale'] = new_meta['xscale'] / scale_factor - new_meta['yscale'] = new_meta['yscale'] / scale_factor - + new_meta['im_scale'] = new_meta['im_scale'] / self.upscale_factor + new_meta['fd_scale'] = new_meta['im_scale'] / self.upscale_factor + new_meta['xscale'] = new_meta['xscale'] / self.upscale_factor + new_meta['yscale'] = new_meta['yscale'] / self.upscale_factor except: pass @@ -101,24 +99,24 @@ def create_new_map(self, new_data, scale_factor, add_noise, model_name, config_d # Add keywords related to conversion try: new_meta['instrume'] = new_meta['instrume'] + '-2HMI_HR' - except: + except KeyError: new_meta['instrume'] = new_meta['telescop'] + '-2HMI_HR' new_meta['hrkey1'] = '---------------- HR ML Keywords Section ----------------' new_meta['date-ml'] = str(datetime.utcnow()) new_meta['nn-model'] = model_name new_meta['loss'] = ', '.join( - '{!s}={!r}'.format(key, val) for (key, val) in config_data['loss'].items()) + '{!s}={!r}'.format(key, val) for (key, val) in self.loss.items()) new_meta['conv_doi'] = 'https://doi.org/10.5281/zenodo.3750372' new_meta['hrkey2'] = '---------------- HR ML Keywords Section ----------------' new_map = Map(new_data, new_meta) - if add_noise: - noise = np.random.normal(loc=0.0, scale=add_noise, size=new_map.data.shape) + if self.add_noise: + noise = np.random.normal(loc=0.0, scale=self.add_noise, size=new_map.data.shape) new_map.data[:] = new_map.data[:] + noise[:] array_radius = get_array_radius(new_map) - new_map.data[array_radius >= 1] = padding + new_map.data[array_radius >= 1] = self.padding return new_map diff --git a/source/load.py b/source/load.py deleted file mode 100644 index cdb3794..0000000 --- a/source/load.py +++ /dev/null @@ -1,45 +0,0 @@ -import os - -import torch - -from google.cloud import storage -from source.utils import get_logger - -logger = get_logger(__name__) - - -def load_from_google_cloud(run_name, epoch, model): - """ - Construct a torch model from pe-trained model run_name stored on goolge cloud. 
- - :param run_name: string - :param epoch: int - :param model: torch model - :return: torch model with pre trained parameters - """ - - gcs_storage_client = storage.Client() - - bucket = gcs_storage_client.bucket('fdl-mag-experiments') - blob = bucket.blob(f'checkpoints/{run_name}/epoch_{epoch}') - - if not os.path.exists(f'checkpoints/{run_name}'): - logger.info(f'Creating checkpoint folder: checkpoints/{run_name}') - os.makedirs(f'checkpoints/{run_name}', exist_ok=True) - if not os.path.exists(f'checkpoints/{run_name}/epoch_{epoch}'): - logger.info(f'Downloading checkpoint: {epoch}') - blob.download_to_filename(f'checkpoints/{run_name}/epoch_{epoch}') - - checkpoint = torch.load(f'checkpoints/{run_name}/epoch_{epoch}', map_location='cpu') - logger.info(f'Loading Model: fdl-mag-experiments/checkpoints/{run_name}/epoch_{epoch}') - - if list(checkpoint['model_state_dict'].keys())[0].split('.')[0] == 'module': - state_dict = {} - for key, value in checkpoint['model_state_dict'].items(): - state_dict['.'.join(key.split('.')[1:])] = value - - model.load_state_dict(state_dict) - else: - model.load_state_dict(checkpoint['model_state_dict']) - - return model diff --git a/source/utils.py b/source/utils.py index 38a0bbc..2cea80d 100644 --- a/source/utils.py +++ b/source/utils.py @@ -9,16 +9,15 @@ def get_logger(name): Returns ------- - - logger : logger instance - + logging.Logger + Logger instance """ logger = logging.getLogger(name) logger.setLevel(logging.DEBUG) formatter = logging.Formatter(fmt="%(asctime)s %(levelname)s %(name)s: %(message)s", datefmt="%Y-%m-%d - %H:%M:%S") console = logging.StreamHandler(sys.stdout) - console.setLevel(logging.DEBUG) + console.setLevel(logging.INFO) console.setFormatter(formatter) logfile = logging.FileHandler('run.log', 'w')
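
Two short sketches for readers following the refactor. Both are illustrative only: the helper names below are not part of the repository.

Patch 1 collapses the nested try/except checkpoint loading into a single loop, and patch 2 reuses the same loop in get_model. The underlying idea, sketched standalone under the assumption that the checkpoint was saved from a torch.nn.DataParallel-wrapped model: such checkpoints prefix every key with 'module.', and stripping that prefix makes the file loadable into an unwrapped model as well.

def normalise_state_dict(raw_state_dict):
    # DataParallel saves keys as 'module.layer.weight'; drop the wrapper
    # prefix so the dict matches an unwrapped model's parameter names.
    return {(key.split('.', 1)[1] if key.startswith('module.') else key): value
            for key, value in raw_state_dict.items()}

# usage: model.load_state_dict(normalise_state_dict(checkpoint['model_state_dict']))

The patch-based inference path relies on FitsFileDataset cutting size x size tiles with a stride equal to the tile size (no overlap) and get_image_from_patches re-tiling the network outputs into a full-disk image. A minimal NumPy round trip of that layout, assuming a square image whose side is an exact multiple of the tile size:

import numpy as np

def split_into_tiles(img, size):
    # (H, W) -> (N, size, size), row-major tile order, stride == size
    h, w = img.shape
    assert h % size == 0 and w % size == 0, 'image must tile exactly'
    tiles = img.reshape(h // size, size, w // size, size)
    return tiles.transpose(0, 2, 1, 3).reshape(-1, size, size)

def reassemble_tiles(tiles):
    # Inverse of split_into_tiles for a square grid of tiles; the square-root
    # recovers the grid width, as in get_image_from_patches.
    n, size, _ = tiles.shape
    grid = int(np.sqrt(n))
    assert grid * grid == n, 'expected a square grid of tiles'
    rows = tiles.reshape(grid, grid, size, size)
    return rows.transpose(0, 2, 1, 3).reshape(grid * size, grid * size)

img = np.arange(64 * 64, dtype=float).reshape(64, 64)
tiles = split_into_tiles(img, 32)               # -> (4, 32, 32)
assert np.array_equal(reassemble_tiles(tiles), img)

Because the stride equals the tile size, reassembly is a pure reshape with no blending; this is why the 4096/target_factor padding in scale_rotate matters, as it guarantees the disk divides into whole tiles.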