diff --git a/convert2HMI.py b/convert2HMI.py
index 153a61c..33342d7 100644
--- a/convert2HMI.py
+++ b/convert2HMI.py
@@ -1,169 +1,208 @@
-import os
-import sys
-
 import argparse
-import yaml
 import logging
+from itertools import chain
+from pathlib import Path
 
 import numpy as np
-from astropy.io import fits
-import astropy.units as u
-import sunpy.map
-
 import torch
+import yaml
 from torch.utils.data.dataloader import DataLoader
 
 from source.models.model_manager import BaseScaler
-from source.dataset import FitsFileDataset
-from source.data_utils import get_array_radius, get_image_from_array, plot_magnetogram
+from source.utils import get_logger
 
 
-def get_logger(name):
+def get_config(instrument, fulldisk, zero_outside, add_noise, no_rescale, **kwargs):
     """
-    Return a logger for current module
+    Get config object setting values passed.
+
+    Parameters
+    ----------
+    instrument : str
+        Instrument name
+    fulldisk : bool
+        Fulldisk based inference, default is patch based
+    zero_outside : bool
+        Set region outside solar radius to zeros instead of default nan values
+    add_noise : float, optional
+        Scale or standard deviation of noise to add
+    no_rescale : bool, optional
+        Disable rescaling
 
     Returns
    -------
-
-    logger : logger instance
-
+    tuple
+        Run directory and config dict
     """
-    logger = logging.getLogger(name)
-    logger.setLevel(logging.DEBUG)
-    formatter = logging.Formatter(fmt="%(asctime)s %(levelname)s %(name)s: %(message)s",
-                                  datefmt="%Y-%m-%d - %H:%M:%S")
-    console = logging.StreamHandler(sys.stdout)
-    console.setLevel(logging.DEBUG)
-    console.setFormatter(formatter)
-
-    logfile = logging.FileHandler('run.log', 'w')
-    logfile.setLevel(logging.DEBUG)
-    logfile.setFormatter(formatter)
-
-    logger.addHandler(console)
-    logger.addHandler(logfile)
-
-    return logger
-
-if __name__ == '__main__':
-    logger = get_logger(__name__)
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--instrument', required=True)
-    parser.add_argument('--data_path')
-    parser.add_argument('--destination')
-    parser.add_argument('--add_noise')
-    parser.add_argument('--plot', action='store_true')
-    parser.add_argument('--overwrite', action='store_true')
-    parser.add_argument('--use_patches', action='store_true')
-    parser.add_argument('--zero_outside', action='store_true')
-    parser.add_argument('--no_rescale', action='store_true')
-
-
-    args = parser.parse_args()
-    instrument = args.instrument.lower()
     if instrument == 'mdi':
-        run = 'checkpoints/mdi/20200312194454_HighResNet_RPRCDO_SSIMGradHistLoss_mdi_19'
+        run_dir = Path('checkpoints/mdi/20200312194454_HighResNet_RPRCDO_SSIMGradHistLoss_mdi_19')
     elif instrument == 'gong':
-        run = 'checkpoints/gong/20200321142757_HighResNet_RPRCDO_SSIMGradHistLoss_gong_1'
-    else:
-        raise RuntimeError(f'mdi and gong are the only valid instruments.')
+        run_dir = Path('checkpoints/gong/20200321142757_HighResNet_RPRCDO_SSIMGradHistLoss_gong_1')
 
-    with open(run + '.yml', 'r') as stream:
+    with run_dir.with_suffix('.yml').open() as stream:
         config_data = yaml.load(stream, Loader=yaml.SafeLoader)
 
+    config_data['cli'] = {'fulldisk': fulldisk,
+                          'zero_outside': zero_outside,
+                          'add_noise': add_noise}
+
+    config_data['instrument'] = instrument
+
     data_config = config_data['data']
-    norm = 3500
-    if 'normalisation' in data_config.keys():
-        norm = data_config['normalisation']
+
+    if 'normalisation' not in data_config.keys():
+        data_config['normalisation'] = 3500.0
 
     padding = np.nan
-    if args.zero_outside:
+    if zero_outside:
         padding = 0
+    data_config['padding'] = padding
 
     rescale = True
-    if args.no_rescale:
+    if no_rescale:
         rescale = False
+    data_config['rescale'] = rescale
 
     net_config = config_data['net']
-    model_name = net_config['name']
-    upscale_factor = 4
-    if 'upscale_factor' in net_config.keys():
-        upscale_factor = net_config['upscale_factor']
-    model = BaseScaler.from_dict(config_data)
+    if 'upscale_factor' not in net_config.keys():
+        net_config['upscale_factor'] = 4
 
-    device = torch.device("cpu")
-    model = model.to(device)
+    return run_dir, config_data
+
+
+def get_model(run, config):
+    """
+    Get a model based on the run and config data
 
-    checkpoint = torch.load(run, map_location='cpu')
-    try:
+    Parameters
+    ----------
+    run : pathlib.Path
+        Path to run directory
+    config : dict
+        Config data
+
+    Returns
+    -------
+    source.models.model_manager.TemplateModel
+        The model
+    """
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    config['device'] = device
+    amodel = BaseScaler.from_dict(config)
+    amodel = amodel.to(device)
+    checkpoint = torch.load(run.as_posix(), map_location='cpu')
+    state_dict = {}
+    for key, value in checkpoint['model_state_dict'].items():
+        if key.startswith('module'):
+            state_dict['.'.join(key.split('.')[1:])] = value
+        else:
+            state_dict[key] = value
+    amodel.load_state_dict(state_dict)
+
+    return amodel
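A note on the state-dict handling in get_model above: checkpoints saved from a model wrapped in torch.nn.DataParallel carry a 'module.' prefix on every parameter name, so the loop strips it before loading into the bare model. A minimal sketch of that behaviour (toy layer, purely illustrative):

    import torch.nn as nn

    layer = nn.Linear(4, 2)
    wrapped = nn.DataParallel(layer)

    print(list(layer.state_dict()))    # ['weight', 'bias']
    print(list(wrapped.state_dict()))  # ['module.weight', 'module.bias']
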
+
+
+def convert(in_file, out_file, config, patchsize=32):
+    """
+    Convert a file to HMI
+
+    Parameters
+    ----------
+    in_file : pathlib.Path
+        Input fits file
+    out_file : pathlib.Path
+        Output fits file
+    config : dict
+        Configuration dictionary
+    patchsize : int
+        Size of the patches created
+    """
+    # Really slow imports, so only import if we reach the point where they are needed
+    from source.dataset import FitsFileDataset
+    from source.data_utils import get_array_radius, get_image_from_patches, plot_magnetogram
+
+    norm = config['data']['normalisation']
+    device = config['device']
+    fulldisk = config['cli']['fulldisk']
+
+    file_dset = FitsFileDataset(in_file, patchsize, config)
+    inferred = None
+    # Try full disk
+    if fulldisk:
         try:
-            model.load_state_dict(checkpoint['model_state_dict'])
+            logger.info('Attempting full disk inference...')
+            in_fd = np.stack([file_dset.map.data, get_array_radius(file_dset.map)], axis=0)
+            inferred = model.forward(
+                torch.from_numpy(in_fd[None]).to(device).float()).detach().cpu().numpy()[
+                0, ...] * norm
+            logger.info('Success.')
-        except:
-            state_dict = {}
-            for key, value in checkpoint['model_state_dict'].items():
-                state_dict['.'.join(key.split('.')[1:])] = value
-            model.load_state_dict(state_dict)
-    except:
-        state_dict = {}
-        for key, value in checkpoint['model_state_dict'].items():
-            state_dict['.'.join(np.append(['module'], key.split('.')[0:]))] = value
-        model.load_state_dict(state_dict)
+        except Exception:
+            logger.info('Full disk inference failed', exc_info=True)
+    else:
+        logger.info('Attempting inference on patches...')
+        dataloader = DataLoader(file_dset, batch_size=8, shuffle=False)
 
-    list_of_files = []
-    for (dirpath, dirnames, filenames) in os.walk(args.data_path):
-        list_of_files += [os.path.join(dirpath, file) for file in filenames
-                          if file.endswith('.fits') or file.endswith('.fits.gz')]
+        output_patches = []
 
-    os.makedirs(args.destination, exist_ok=True)
+        for patch in dataloader:
+            patch = patch.to(device)
+            output = model.forward(patch) * norm
 
-    for file in list_of_files:
+            output_patches.append(output.detach().cpu().numpy())
 
-        logger.info(f'Processing {file}')
+        inferred = get_image_from_patches(output_patches)
+        logger.info('Success.')
 
-        output_file = args.destination + '/' + '.'.join(file.split('/')[-1].split('.gz')[0].split('.')[0:-1])
-        if os.path.exists(output_file + '_HR.fits') and not args.overwrite:
-            logger.info(f'{file} already exists')
+    if inferred is not None:
+        inferred_map = file_dset.create_new_map(inferred, model.name)
+        inferred_map.save(out_file.as_posix(), overwrite=True)
 
-        else:
+        if args.plot:
+            plot_magnetogram(inferred_map, out_file.with_suffix('.png'))
 
-            file_dset = FitsFileDataset(file, 32, norm, instrument, rescale, upscale_factor)
+        del inferred_map
 
-            # Try full disk
-            success_sw = False
-            if not args.use_patches:
-                success_sw = True
-                try:
-                    logger.info(f'Attempting full disk inference...')
-                    in_fd = np.stack([file_dset.map.data, get_array_radius(file_dset.map)], axis=0)
-                    inferred = model.forward(torch.from_numpy(in_fd[None]).to(device).float()).detach().numpy()[0,...]*norm
-                    logger.info(f'Success.')
-                except Exception as e:
-                    logger.info(f'Failure. {e}')
-                    success_sw = False
 
+if __name__ == '__main__':
+    logging.root.setLevel('INFO')
+    logger = get_logger(__name__)
 
-            if not success_sw or args.use_patches:
-                logger.info(f'Attempting inference on patches...')
-                dataloader = DataLoader(file_dset, batch_size=8, shuffle=False)
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--instrument', required=True, choices=['mdi', 'gong'])
+    parser.add_argument('--source_dir', required=True, type=str)
+    parser.add_argument('--destination_dir', required=True, type=str)
+    parser.add_argument('--add_noise', type=float)
+    parser.add_argument('--plot', action='store_true')
+    parser.add_argument('--overwrite', action='store_true')
+    parser.add_argument('--fulldisk', action='store_true')
+    parser.add_argument('--zero_outside', action='store_true')
+    parser.add_argument('--no_rescale', action='store_true')
 
-                output_patches = []
+    args = parser.parse_args()
 
-                for input in dataloader:
+    source_dir = Path(args.source_dir)
+    destination_dir = Path(args.destination_dir)
+    overwrite = args.overwrite
 
-                    input = input.to(device)
-                    output = model.forward(input) * norm
+    checkpoint_dir, config_data = get_config(**vars(args))
 
-                    output_patches.append(output.detach().cpu().numpy())
+    model = get_model(checkpoint_dir, config_data)
 
-                inferred = get_image_from_array(output_patches)
-                logger.info(f'Success.')
+    source_files = chain(source_dir.rglob('*.fits'), source_dir.rglob('*.fits.gz'))
 
-            inferred_map = file_dset.create_new_map(inferred, upscale_factor, args.add_noise, model_name, config_data, padding)
-            inferred_map.save(output_file + '_HR.fits', overwrite=True)
+    destination_dir.mkdir(exist_ok=True, parents=True)
 
-            if args.plot:
-                plot_magnetogram(inferred_map, output_file + '_HR.png')
+    for file in source_files:
+        logger.info(f'Processing {file}')
+        # Path.stem strips only the final suffix, so also trim the '.fits'
+        # left behind on gzipped inputs
+        out_path = destination_dir / (file.name.split('.fits')[0] + '_HR.fits')
 
-            del inferred_map
+        if out_path.exists() and not overwrite:
+            logger.info(f'{out_path} already exists')
+        else:
+            convert(file, out_path, config_data)
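For reference, a typical invocation of the reworked script would look like this (directory paths are illustrative; the checkpoint and its .yml config are resolved from the --instrument choice, and inference is patch-based unless --fulldisk is given):

    python convert2HMI.py --instrument mdi \
        --source_dir /data/mdi/fits --destination_dir /data/mdi/hmi_like \
        --fulldisk --plot --overwrite
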
diff --git a/requirements.txt b/requirements.txt
index 5a3baad..d3e54b4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,7 +14,5 @@ torch==1.1.0
 torchvision==0.3.0
 tensorboard==1.14.0
 tensorboardX==1.8
-gitpython==3.1.0
-tensorflow==1.15.2
 scikit-learn==0.21.3
 gcsfs==0.2.3
\ No newline at end of file
diff --git a/source/data_utils.py b/source/data_utils.py
index 97b7adb..80c9ef8 100644
--- a/source/data_utils.py
+++ b/source/data_utils.py
@@ -1,17 +1,12 @@
 import numpy as np
-import math
-import datetime
 
 from astropy import units as u
-from sklearn.feature_extraction import image
 from astropy.coordinates import SkyCoord
-from sunpy.map import Map
 from astropy.io import fits
-from astropy.time import Time
-from astropy.coordinates import solar_system_ephemeris, EarthLocation, get_body
-
-
-import matplotlib.pyplot as plt
+from matplotlib import pyplot as plt
+from sklearn.feature_extraction import image
+from sunpy.cm import cm as scm
+from sunpy.map import Map
 
 from source.utils import disable_warnings, get_logger
 
@@ -19,22 +14,21 @@
 logger = get_logger(__name__)
 
 
-def map_prep(file, instrument, *keyward_args):
+def map_prep(file, instrument):
     """
-    Return a processed hmi magnetogram and path
+    Return a processed magnetogram
 
     Parameters
     ----------
-    file : file desctiptor
-    instrument: string
-
-    keyward_args :
+    file : str
+        File to prep
+    instrument : str
+        Instrument name
 
     Returns
-
-    tuple : preped map and filepath
     -------
-
+    sunpy.map.Map
+        Prepped map
     """
     # Open fits file as HUDL and fix header
@@ -45,172 +39,54 @@ def map_prep(file, instrument, *keyward_args):
     if len(hdul) == 2:
         sun_map = Map(hdul[1].data, hdul[1].header)
-
-    elif len(hdul) == 1:
-        if instrument == 'mdi':
-
-            header = hdul[0].header
-            if header['SOLAR_P0']:
-                header['RSUN_OBS'] = header['OBS_R0']
-                header['RSUN_REF'] = 696000000
-                header['CROTA2'] = -header['SOLAR_P0']
-                header['CRVAL1'] = 0.000000
-                header['CRVAL2'] = 0.000000
-                header['CUNIT1'] = 'arcsec'
-                header['CUNIT2'] = 'arcsec'
-                header['DSUN_OBS'] = header['OBS_DIST']
-                header['DSUN_REF'] = 1
-
-            try:
-                header.pop('SOLAR_P0')
-                header.pop('OBS_DIST')
-                header.pop('OBS_R0')
-            except:
-                pass
-
-            data = hdul[0].data
-
-        if instrument == 'gong':
-
-            header = hdul[0].header
-            if len(header['DATE-OBS'])<22:
-                header['RSUN_OBS'] = header['RADIUS'] * 180 / np.pi * 60 * 60
-                header['RSUN_REF'] = 696000000
-                header['CROTA2'] = 0
-                header['CUNIT1'] = 'arcsec'
-                header['CUNIT2'] = 'arcsec'
-                header['DSUN_OBS'] = header['DISTANCE'] * 149597870691
-                header['DSUN_REF'] = 149597870691
-                header['cdelt1'] = 2.5534
-                header['cdelt2'] = 2.5534
-
-                header['CTYPE1'] = 'HPLN-TAN'
-                header['CTYPE2'] = 'HPLT-TAN'
-
-
-                date = header['DATE-OBS']
-                header['DATE-OBS'] = date[0:4] + '-' + date[5:7] + '-' + date[8:10] + 'T' + header['TIME-OBS'][0:11]
-
-            data = hdul[0].data
-
-        if instrument == 'spmg':
-            header = hdul[0].header
-            header['cunit1'] = 'arcsec'
-            header['cunit2'] = 'arcsec'
-            header['CDELT1'] = header['CDELT1A']
-            header['CDELT2'] = header['CDELT2A']
-            header['CRVAL1'] = 0
-            header['CRVAL2'] = 0
-            header['RSUN_OBS'] = header['EPH_R0 ']
-            header['CROTA2'] = 0
-            header['CRPIX1'] = header['CRPIX1A']
-            header['CRPIX2'] = header['CRPIX2A']
-            header['PC2_1'] = 0
-            header['PC1_2'] = 0
-            header['RSUN_REF'] = 696000000
-
-            # Adding distance to header
-            t = Time(header['DATE-OBS'])
-            loc = EarthLocation.of_site('kpno')
-            with solar_system_ephemeris.set('builtin'):
-                sun = get_body('sun', t, loc)
-            header['DSUN_OBS'] = sun.distance.to('m').value
-            header['DSUN_REF'] = 149597870691
-
-            # selecting right layer for data
-            data = hdul[0].data[5, :, :]
-
-        if instrument == 'kp512':
-            header = hdul[0].header
-            header['cunit1'] = 'arcsec'
-            header['cunit2'] = 'arcsec'
-            header['CDELT1'] = header['CDELT1A']
-            header['CDELT2'] = header['CDELT2A']
-            header['CRVAL1'] = 0
-            header['CRVAL2'] = 0
-            header['RSUN_OBS'] = header['EPH_R0 ']
-            header['CROTA2'] = 0
-            header['CRPIX1'] = header['CRPIX1A']
-            header['CRPIX2'] = header['CRPIX2A']
-            header['PC2_1'] = 0
-            header['PC1_2'] = 0
+        data = sun_map.data.astype('>f4')
+        header = sun_map.meta
+    else:
+        data = hdul[0].data
+        header = hdul[0].header
+
+    if instrument == 'mdi':
+        if 'SOLAR_P0' in header:
+            header['RSUN_OBS'] = header['OBS_R0']
             header['RSUN_REF'] = 696000000
-
-            # Adding distance to header
-            t = Time(header['DATE-OBS'])
-            loc = EarthLocation.of_site('kpno')
-            with solar_system_ephemeris.set('builtin'):
-                sun = get_body('sun', t, loc)
-            header['DSUN_OBS'] = sun.distance.to('m').value
-            header['DSUN_REF'] = 149597870691
-
-            # selecting right layer for data
-            data = hdul[0].data[2, :, :]
-
-        if instrument == 'mwo':
-
-            file_name = file.name
-
-            # Deconstruct Name to assess date
-            tmpPos = file_name.rfind('_')
-
-            year = int(file_name[tmpPos - 6:tmpPos - 4])
-
-            # Adding century
-            if year < 1960:
-                year += 2000
-            else:
-                year += 1900
-
-            month = int(file_name[tmpPos - 4:tmpPos - 2])
-            day = int(file_name[tmpPos - 2:tmpPos])
-            hr = int(file_name[tmpPos + 1:tmpPos + 3]) - 1
-            mn = int(file_name[tmpPos + 3:tmpPos + 5])
-            sc = 0
-
-            # Fix Times
-            if mn > 59:
-                mn = mn - 60
-                hr = hr + 1
-
-            # Assemble date
-            if hr > 23:
-                tmpDate = datetime.datetime(year, month, day, hr - 24, mn,
-                                            sc) + datetime.timedelta(days=1)
-            else:
-                tmpDate = datetime.datetime(year, month, day, hr, mn, sc)
-
-            header = hdul[0].header
+            header['CROTA2'] = -header['SOLAR_P0']
+            header['CRVAL1'] = 0.000000
+            header['CRVAL2'] = 0.000000
             header['CUNIT1'] = 'arcsec'
             header['CUNIT2'] = 'arcsec'
-            header['CDELT1'] = header['DXB_IMG']
-            header['CDELT2'] = header['DYB_IMG']
-            header['CRVAL1'] = 0.0
-            header['CRVAL2'] = 0.0
-            header['RSUN_OBS'] = (header['R0']) * header['DXB_IMG']
-            header['CROTA2'] = 0.0
-            header['CRPIX1'] = header['X0'] - 0.5
-            header['CRPIX2'] = header['Y0'] - 0.5
-            header['T_OBS'] = tmpDate.strftime('%Y-%m-%dT%H-%M:00.0')
-            header['DATE-OBS'] = tmpDate.strftime('%Y-%m-%dT%H:%M:00.0')
-            header['DATE_OBS'] = tmpDate.strftime('%Y-%m-%dT%H:%M:00.0')
+            header['DSUN_OBS'] = header['OBS_DIST']
+            header['DSUN_REF'] = 1
+
+        if 'DSUN_REF' not in header:
+            header['DSUN_REF'] = u.au.to('m')
+
+        try:
+            header.pop('SOLAR_P0')
+            header.pop('OBS_DIST')
+            header.pop('OBS_R0')
+        except KeyError:
+            pass
+
+    elif instrument == 'gong':
+        if len(header['DATE-OBS']) < 22:
+            header['RSUN_OBS'] = header['RADIUS'] * 180 / np.pi * 60 * 60
             header['RSUN_REF'] = 696000000
+            header['CROTA2'] = 0
+            header['CUNIT1'] = 'arcsec'
+            header['CUNIT2'] = 'arcsec'
+            header['DSUN_OBS'] = header['DISTANCE'] * 149597870691
+            header['DSUN_REF'] = 149597870691
+            header['cdelt1'] = 2.5534
+            header['cdelt2'] = 2.5534
+
+            header['CTYPE1'] = 'HPLN-TAN'
             header['CTYPE2'] = 'HPLT-TAN'
-            header['RSUN_REF'] = 696000000
 
-            # Adding distance to header
-            t = Time(header['DATE-OBS'], format='isot')
-            loc = EarthLocation.of_site('mwo')
-            with solar_system_ephemeris.set('builtin'):
-                sun = get_body('sun', t, loc)
-            header['DSUN_OBS'] = sun.distance.to('m').value
-            header['DSUN_REF'] = 149597870691
-
-            # selecting right layer for data
-            data = hdul[0].data
+            date = header['DATE-OBS']
+            header['DATE-OBS'] = date[0:4] + '-' + date[5:7] + '-' + date[8:10] + 'T' +\
+                header['TIME-OBS'][0:11]
 
-        sun_map = Map(data, header)
+    sun_map = Map(data, header)
 
     return sun_map
@@ -220,30 +96,37 @@
 
     Parameters
     ----------
-    amap
+    amap : sunpy.map.Map
+        Input map
+    target_scale : float
+        Target plate scale in arcsec per pixel
+    target_factor : int
+        Integer scale factor; computed from the plate scales when 0
 
     Returns
     -------
-
+    sunpy.map.Map
+        Scaled and rotated map
     """
     scalex = amap.meta['cdelt1']
     scaley = amap.meta['cdelt2']
 
+    if scalex != scaley:
+        raise ValueError('Square pixels expected')
+
     # Calculate target factor if not provided
     if target_factor == 0:
         target_factor = np.round(scalex / target_scale)
 
     ratio_plate = target_factor * target_scale / scalex
-    # logger.info(np.round(scalex / target_scale) / scalex * target_scale)
+    logger.debug(np.round(scalex / target_scale) / scalex * target_scale)
 
     ratio_dist = amap.meta['dsun_obs'] / amap.meta['dsun_ref']
-    # logger.info(ratio_dist)
+    logger.debug(ratio_dist)
 
     # Pad image, if necessary
     new_shape = int(4096/target_factor)
 
     # Reform map to new size if original shape is too small
-
     if new_shape > amap.data.shape[0]:
         new_fov = np.zeros((new_shape, new_shape)) * np.nan
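A quick numerical check of the target_factor logic above. The 0.504273 arcsec/pixel target and GONG's 2.5534 arcsec/pixel come from the code; the MDI value of ~1.98 arcsec/pixel is an assumed nominal figure:

    import numpy as np

    target_scale = 0.504273  # HMI plate scale, arcsec per pixel
    for name, cdelt1 in [('mdi', 1.98), ('gong', 2.5534)]:
        target_factor = np.round(cdelt1 / target_scale)
        ratio_plate = target_factor * target_scale / cdelt1
        print(name, target_factor, round(ratio_plate, 3))
    # mdi 4.0 1.019  -> pixels resampled ~2% to land exactly on 4x the HMI grid
    # gong 5.0 0.987
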
@@ -269,7 +152,7 @@
     x_scale = ((rot_map.scale.axis1 * amap.dimensions.x) / 2)
     y_scale = ((rot_map.scale.axis2 * amap.dimensions.y) / 2)
-    # logger.info(f'x-scale {x_scale}, y-scale {y_scale}')
+    logger.debug(f'x-scale {x_scale}, y-scale {y_scale}')
 
     if x_scale != y_scale:
         logger.error(f'x-scale: {x_scale} and y-scale {y_scale} do not match')
@@ -291,44 +174,52 @@
     crop_map.meta['xscale'] = target_factor * target_scale
     crop_map.meta['yscale'] = target_factor * target_scale
-
     return crop_map
 
 
-def get_patch(amap, size):
+def get_patches(amap, size):
     """
-    create patches of dimension size * size with a defined stride.
+    Create patches of dimension size * size with a stride of size.
+    Since the stride equals the patch size, there is no overlap.
 
     Parameters
     ----------
-    amap: sunpy map
+    amap : sunpy.map.Map
+        Input map to create tiles from
 
-    size: integer
-        size of each patch
+    size : int
+        Size of each patch
 
     Returns
     -------
-    numpy array [num_patches, num_channel, size, size]
-        channels are magnetic field, radial distance relative to radius
+    numpy.array (num_patches, num_channel, size, size)
+        Channels are magnetic field, radial distance relative to radius
     """
     array_radius = get_array_radius(amap)
 
     patches = image.extract_patches(amap.data, (size, size), extraction_step=size)
-    patches = patches.reshape([-1] + list((size, size)))
+    patches = patches.reshape(-1, size, size)
     patches_r = image.extract_patches(array_radius, (size, size), extraction_step=size)
-    patches_r = patches_r.reshape([-1] + list((size, size)))
+    patches_r = patches_r.reshape(-1, size, size)
 
     return np.stack([patches, patches_r], axis=1)
 
 
 def get_array_radius(amap):
     """
-    Compute an array with the radial coordinate for each pixel
-    :param amap:
-    :return: (W, H) array
+    Compute an array with the radial coordinate for each pixel.
+
+    Parameters
+    ----------
+    amap : sunpy.map.Map
+        Input map
+
+    Returns
+    -------
+    numpy.ndarray
+        Array full of radius values
     """
     x, y = np.meshgrid(*[np.arange(v.value) for v in amap.dimensions]) * u.pixel
     hpc_coords = amap.pixel_to_world(x, y)
@@ -337,16 +228,23 @@
     return array_radius
 
 
-def get_image_from_array(list_patches):
+def get_image_from_patches(list_patches):
     """
     Reconstruct from a list of patches the full disk image
-    :param list_patches:
-    :return:
+
+    Parameters
+    ----------
+    list_patches : list
+        Patches produced from `get_patches`
+
+    Returns
+    -------
+    numpy.ndarray
+        Reconstructed image
     """
     out = np.array(list_patches)
     out_r = out.reshape(out.shape[0] * out.shape[1], out.shape[2], out.shape[3])
-    size = int(math.sqrt(out_r.shape[0]))
+    size = int(np.sqrt(out_r.shape[0]))
     out_array = np.array_split(out_r, size, axis=0)
     out_array = np.concatenate(out_array, axis=1)
     out_array = np.concatenate(out_array, axis=1)
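Since get_patches tiles with a stride equal to the patch size, get_image_from_patches is its exact inverse on square inputs. A self-contained sanity check of that reassembly logic (toy array, not the project API):

    import numpy as np

    p, n = 4, 3  # a 3x3 grid of 4x4 tiles
    img = np.arange((n * p) ** 2, dtype=float).reshape(n * p, n * p)

    # Tile with stride == patch size (no overlap), as get_patches does
    tiles = img.reshape(n, p, n, p).swapaxes(1, 2).reshape(-1, p, p)

    # Reassemble exactly the way get_image_from_patches does
    grid = int(np.sqrt(tiles.shape[0]))
    rows = np.array_split(tiles, grid, axis=0)
    rebuilt = np.concatenate(np.concatenate(rows, axis=1), axis=1)

    assert np.array_equal(img, rebuilt)
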
@@ -354,11 +252,24 @@
     return out_array
 
 
-def plot_magnetogram(amap, file, scale=1, vmin=-2000, vmax=2000, cmap=plt.cm.get_cmap('hmimag')):
+def plot_magnetogram(amap, file, scale=1, vmin=-2000, vmax=2000, cmap=scm.hmimag):
     """
     Plot magnetogram
-    :param amap:
-    :return: (W, H) array
+
+    Parameters
+    ----------
+    amap : sunpy.map.Map
+        Magnetogram map to plot
+    file : pathlib.Path
+        Filename to save plot as
+    scale :
+
+    vmin : float
+        Min value for color map
+    vmax : float
+        Max value for color map
+    cmap : matplotlib.colors.Colormap
+        Colormap
     """
 
     # Size definitions
@@ -383,21 +294,20 @@
     ppxx = pxx / fszh  # Horizontal size of each panel in relative units
     ppxy = pxy / fszv  # Vertical size of each panel in relative units
     ppadv = padv / fszv  # Vertical padding in relative units
-    ppadv2 = padv2 / fszv  # Vertical padding in relative units
+    # Never used so commented out
+    # ppadv2 = padv2 / fszv  # Vertical padding in relative units
     ppadh = padh / fszh  # Horizontal padding the edge of the figure in relative units
-    ppadh2 = padh2 / fszh  # Horizontal padding between panels in relative units
+    # Never used so commented out
+    # ppadh2 = padh2 / fszh  # Horizontal padding between panels in relative units
 
-    ## Start Figure
+    # Start Figure
     fig = plt.figure(figsize=(fszh / dpi, fszv / dpi), dpi=dpi)
 
     # ## Add Perihelion
     ax1 = fig.add_axes([ppadh + ppxx, ppadv + ppxy, ppxx, ppxy])
     ax1.imshow(amap.data, vmin=vmin, vmax=vmax, cmap=cmap, origin='lower')
     ax1.set_axis_off()
-    ax1.text(0.99, 0.99, 'ML Output', horizontalalignment='right', verticalalignment='top', color='k',
-             transform=ax1.transAxes)
+    ax1.text(0.99, 0.99, 'ML Output', horizontalalignment='right', verticalalignment='top',
+             color='k', transform=ax1.transAxes)
 
     fig.savefig(file, bbox_inches='tight', dpi=dpi, pad_inches=0)
-
-
-
diff --git a/source/dataset.py b/source/dataset.py
index a4c786d..e0e2d70 100644
--- a/source/dataset.py
+++ b/source/dataset.py
@@ -1,35 +1,51 @@
-import numpy as np
-
+from datetime import datetime
 
-from torch.utils.data import Dataset
-from torch import from_numpy
+import numpy as np
 from sunpy.map import Map
-from datetime import datetime
+from torch import from_numpy
+from torch.utils.data import Dataset
 
-from source.data_utils import get_patch, get_array_radius, map_prep, scale_rotate
+from source.data_utils import get_patches, get_array_radius, map_prep, scale_rotate
 
 
 class FitsFileDataset(Dataset):
     """
     Construct a dataset of patches from a fits file
     """
-    def __init__(self, file, size, norm, instrument, rescale, upscale_factor):
-        map = map_prep(file, instrument)
-        map.data[:] = map.data[:]/norm
+    def __init__(self, file, size, config_data):
+
+        self.instrument = config_data['instrument']
+        self.norm = config_data['data']['normalisation']
+        self.rescale = config_data['data']['rescale']
+        self.upscale_factor = config_data['net']['upscale_factor']
+        self.padding = config_data['data']['padding']
+        self.loss = config_data['loss']
+        self.add_noise = config_data['cli']['add_noise']
+
+        amap = map_prep(file, self.instrument)
+        amap.data[:] = amap.data[:] / self.norm
 
         # Detecting need for rescale
-        if rescale and np.abs(1 - (map.meta['cdelt1']/0.504273)/upscale_factor) > 0.01:
-            map = scale_rotate(map, target_factor=upscale_factor)
+        if self.rescale and np.abs(1 - (amap.meta['cdelt1']/0.504273) / self.upscale_factor) > 0.01:
+            amap = scale_rotate(amap, target_factor=self.upscale_factor)
 
-        self.data = get_patch(map, size)
-        self.map = map
+        self.data = get_patches(amap, size)
+        self.map = amap
 
     def __getitem__(self, idx):
         """
-        Create torch tensor from patch with id idx
-        :param idx:
-        :return: tensor (W, H, 2)
+        Get a patch
+
+        Parameters
+        ----------
+        idx : int
+            Index
+
+        Returns
+        -------
+        torch.tensor
+            Patch
         """
         patch = self.data[idx, ...]
         patch[patch != patch] = 0
@@ -39,26 +55,33 @@ def __getitem__(self, idx):
         return from_numpy(patch).float()
 
     def __len__(self):
         return self.data.shape[0]
 
-    def create_new_map(self, new_data, scale_factor, add_noise, model_name, config_data, padding):
+    def create_new_map(self, new_data, model_name):
         """
         Adjust header to match upscaling factor and add new keywords
-        :return:
+
+        Parameters
+        ----------
+        new_data : numpy.ndarray
+            Superresolved map data
+        model_name : str
+            Name of the model used
         """
         new_meta = self.map.meta.copy()
 
         # Changing scale and center
-        new_meta['crpix1'] = new_meta['crpix1'] - self.map.data.shape[0] / 2 + self.map.data.shape[0] * scale_factor / 2
-        new_meta['crpix2'] = new_meta['crpix2'] - self.map.data.shape[1] / 2 + self.map.data.shape[1] * scale_factor / 2
-        new_meta['cdelt1'] = new_meta['cdelt1'] / scale_factor
-        new_meta['cdelt2'] = new_meta['cdelt2'] / scale_factor
+        new_meta['crpix1'] = (new_meta['crpix1'] - self.map.data.shape[0] / 2 +
+                              self.map.data.shape[0] * self.upscale_factor / 2)
+        new_meta['crpix2'] = (new_meta['crpix2'] - self.map.data.shape[1] / 2 +
+                              self.map.data.shape[1] * self.upscale_factor / 2)
+        new_meta['cdelt1'] = new_meta['cdelt1'] / self.upscale_factor
+        new_meta['cdelt2'] = new_meta['cdelt2'] / self.upscale_factor
 
         try:
-            new_meta['im_scale'] = new_meta['im_scale'] / scale_factor
-            new_meta['fd_scale'] = new_meta['im_scale'] / scale_factor
-            new_meta['xscale'] = new_meta['xscale'] / scale_factor
-            new_meta['yscale'] = new_meta['yscale'] / scale_factor
-
+            new_meta['im_scale'] = new_meta['im_scale'] / self.upscale_factor
+            new_meta['fd_scale'] = new_meta['im_scale'] / self.upscale_factor
+            new_meta['xscale'] = new_meta['xscale'] / self.upscale_factor
+            new_meta['yscale'] = new_meta['yscale'] / self.upscale_factor
-        except:
+        except KeyError:
             pass
@@ -68,36 +91,32 @@ def create_new_map(self, new_data, model_name):
         # Changing data info
         new_meta['datamin'] = np.nanmin(self.map.data)
         new_meta['datamax'] = np.nanmax(self.map.data)
-        new_meta['data_rms'] = np.sqrt(np.nanmean(self.map.data**2))
+        new_meta['data_rms'] = np.sqrt(np.nanmean(self.map.data ** 2))
         new_meta['datamean'] = np.nanmean(self.map.data)
         new_meta['datamedn'] = np.nanmedian(self.map.data)
         new_meta['dataskew'] = np.nanmedian(self.map.data)
 
-
         # Add keywords related to conversion
         try:
             new_meta['instrume'] = new_meta['instrume'] + '-2HMI_HR'
-        except:
+        except KeyError:
             new_meta['instrume'] = new_meta['telescop'] + '-2HMI_HR'
 
         new_meta['hrkey1'] = '---------------- HR ML Keywords Section ----------------'
         new_meta['date-ml'] = str(datetime.utcnow())
         new_meta['nn-model'] = model_name
-        new_meta['loss'] = ', '.join('{!s}={!r}'.format(key, val) for (key, val) in config_data['loss'].items())
+        new_meta['loss'] = ', '.join(
+            '{!s}={!r}'.format(key, val) for (key, val) in self.loss.items())
         new_meta['conv_doi'] = 'https://doi.org/10.5281/zenodo.3750372'
         new_meta['hrkey2'] = '---------------- HR ML Keywords Section ----------------'
 
         new_map = Map(new_data, new_meta)
 
-        if add_noise:
-            noise = np.random.normal(loc=0.0, scale=add_noise, size=new_map.data.shape)
+        if self.add_noise:
+            noise = np.random.normal(loc=0.0, scale=self.add_noise, size=new_map.data.shape)
             new_map.data[:] = new_map.data[:] + noise[:]
 
         array_radius = get_array_radius(new_map)
-        new_map.data[array_radius >= 1] = padding
+        new_map.data[array_radius >= 1] = self.padding
 
         return new_map
-
-
-
-
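A detail worth flagging in __getitem__ above: `patch[patch != patch] = 0` exploits the IEEE rule that NaN never compares equal to itself, so off-disk NaN pixels are zero-filled before the tensor conversion. In isolation:

    import numpy as np

    patch = np.array([1.0, np.nan, -3.0])
    patch[patch != patch] = 0  # NaN != NaN is True, everything else False
    print(patch)               # [ 1.  0. -3.]
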
diff --git a/source/load.py b/source/load.py
deleted file mode 100644
index 2417357..0000000
--- a/source/load.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import os
-
-import torch
-
-from google.cloud import storage
-from source.utils import get_logger
-
-logger = get_logger(__name__)
-
-def load_from_google_cloud(run_name, epoch, model):
-    """
-    Construct a torch model from pe-trained model run_name stored on goolge cloud
-    :param run_name: string
-    :param epoch: int
-    :param model: torch model
-    :return: torch model with pre trained parameters
-    """
-
-    gcs_storage_client = storage.Client()
-
-    bucket = gcs_storage_client.bucket('fdl-mag-experiments')
-    blob = bucket.blob(f'checkpoints/{run_name}/epoch_{epoch}')
-
-    if not os.path.exists(f'checkpoints/{run_name}'):
-        logger.info(f'Creating checkpoint folder: checkpoints/{run_name}')
-        os.makedirs(f'checkpoints/{run_name}', exist_ok=True)
-    if not os.path.exists(f'checkpoints/{run_name}/epoch_{epoch}'):
-        logger.info(f'Downloading checkpoint: {epoch}')
-        blob.download_to_filename(f'checkpoints/{run_name}/epoch_{epoch}')
-
-    checkpoint = torch.load(f'checkpoints/{run_name}/epoch_{epoch}', map_location='cpu')
-    logger.info(f'Loading Model: fdl-mag-experiments/checkpoints/{run_name}/epoch_{epoch}')
-
-    if list(checkpoint['model_state_dict'].keys())[0].split('.')[0] == 'module':
-        state_dict = {}
-        for key, value in checkpoint['model_state_dict'].items():
-            state_dict['.'.join(key.split('.')[1:])] = value
-
-        model.load_state_dict(state_dict)
-    else:
-        model.load_state_dict(checkpoint['model_state_dict'])
-
-    return model
diff --git a/source/utils.py b/source/utils.py
index ec65e46..2cea80d 100644
--- a/source/utils.py
+++ b/source/utils.py
@@ -3,22 +3,21 @@
 import warnings
 
-
 def get_logger(name):
     """
-    Return a logger for current module
+    Return a logger for current module.
+
     Returns
     -------
-
-    logger : logger instance
-
+    logging.Logger
+        Logger instance
     """
     logger = logging.getLogger(name)
     logger.setLevel(logging.DEBUG)
     formatter = logging.Formatter(fmt="%(asctime)s %(levelname)s %(name)s: %(message)s",
                                   datefmt="%Y-%m-%d - %H:%M:%S")
     console = logging.StreamHandler(sys.stdout)
-    console.setLevel(logging.DEBUG)
+    console.setLevel(logging.INFO)
     console.setFormatter(formatter)
 
     logfile = logging.FileHandler('run.log', 'w')
@@ -30,6 +29,7 @@
     return logger
 
 
+
 def disable_warnings():
     """
     Disable printing of warnings
@@ -39,6 +39,3 @@ def disable_warnings():
     """
     warnings.simplefilter("ignore")
-
-
-
\ No newline at end of file