Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander-Barth committed May 3, 2024
2 parents 7c1dfbb + 4ca04cf commit cec71e4
Show file tree
Hide file tree
Showing 10 changed files with 168 additions and 67 deletions.
14 changes: 7 additions & 7 deletions .github/ISSUE_TEMPLATE/bug_report.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,27 @@ assignees: ''

---


__Please provide all information reuested here__

**Describe the bug**
A clear and concise description of what the bug is.

**To Reproduce**
Please provide a minimal julia code example which reproduces the behavior (bug, performance regression, ...).
An example is reproducible if it includes all steps allow somebody else to get the same behavior as you are observing.
If your example needs a data set, reduce the size of the data so that it can be easily shared. If your dataset is private,
try to reproduce the error with a public available dataset or random data. Do not send large files via emails.

Please provide a minimal Julia code example which reproduces the behavior (bug, performance regression, ...).
An example is reproducible if it includes all steps allowing somebody else to get the same behavior as you are observing.
If your example needs a dataset, reduce the size of the data so that it can be easily shared. If your dataset is private,
try to reproduce the error with a publicly available dataset or randomly generated data. Do not send large files via email, but provide a download link to the dataset in the issue report.

**Expected behavior**
A clear and concise description of what you expected to happen.

**Environment**
- operating system: [e.g. Ubuntu XX.YY]
- output of the julia command versioninfo()
- output of the julia command `versioninfo()`
- do you use the official binaries from https://julialang.org/downloads/ ?
- DINCAE version

** Full output **
**Full output**

In case of an error, please paste the *full* error message and stack trace.
2 changes: 1 addition & 1 deletion .github/workflows/Documenter.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
name: Documentation
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- run: |
sudo apt update
sudo apt install python3-matplotlib
Expand Down
7 changes: 4 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ jobs:
arch:
- x64
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
Expand All @@ -42,6 +42,7 @@ jobs:
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v1
- uses: codecov/codecov-action@v4
with:
file: lcov.info
token: ${{ secrets.CODECOV_TOKEN }}
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
[![codecov.io](http://codecov.io/github/gher-uliege/DINCAE.jl/coverage.svg?branch=main)](http://codecov.io/github/gher-uliege/DINCAE.jl?branch=main)
[![documentation stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://gher-uliege.github.io/DINCAE.jl/stable/)
[![documentation dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://gher-uliege.github.io/DINCAE.jl/dev/)
[![DOI](https://zenodo.org/badge/193079989.svg)](https://zenodo.org/badge/latestdoi/193079989)

[![Issues](https://img.shields.io/github/issues-raw/gher-uliege/DINCAE.jl?style=plastic)](https://github.com/gher-uliege/DINCAE.jl/issues)
![Issues](https://img.shields.io/github/commit-activity/m/gher-uliege/DINCAE.jl)
Expand Down
9 changes: 4 additions & 5 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@ The method is described in the following articles:
* Barth, A., Alvera-Azcárate, A., Ličer, M., & Beckers, J.-M. (2020). DINCAE 1.0: a convolutional neural network with error estimates to reconstruct sea surface temperature satellite observations. Geoscientific Model Development, 13(3), 1609–1622. https://doi.org/10.5194/gmd-13-1609-2020
* Barth, A., Alvera-Azcárate, A., Troupin, C., & Beckers, J.-M. (2022). DINCAE 2.0: multivariate convolutional neural network with error estimates to reconstruct sea surface temperature satellite and altimetry observations. Geoscientific Model Development, 15(5), 2183–2196. https://doi.org/10.5194/gmd-15-2183-2022

The neural network will be trained on the GPU. Note convolutional neural networks can require a lot of GPU memory depending on the domain size.
# So far, only NVIDIA GPUs are supported by the neural network framework `Knet.jl` used in DINCAE (beside training on the CPU but which is prohibitively slow).
[`Flux.jl`](https://github.com/FluxML/Flux.jl) supports NVIDIA GPUs as well as other brands (see https://fluxml.ai/Flux.jl/stable/gpu/ for details).
Training on the CPU can be performedi, but it is prohibitively slow.
The neural network will be trained on the GPU. Note convolutional neural networks can require a lot of GPU memory depending on the domain size.
[`Flux.jl`](https://github.com/FluxML/Flux.jl) supports NVIDIA GPUs as well as other vendors (see https://fluxml.ai/Flux.jl/stable/gpu/ for details).
Training on the CPU can be performeded, but it is prohibitively slow.

## User API

Expand Down Expand Up @@ -73,5 +72,5 @@ using DINCAE

`DINCAE.jl` depends on `Flux.jl` and `CUDA.jl`, which will automatically be installed.
If you have some problems installing these package you might consult the
[documentation of `Flux.jl`](http://fluxml.ai/Flux.jl/stable/#Installation) or
[documentation of `Flux.jl`](http://fluxml.ai/Flux.jl/stable/#Installation) or
[`CUDA.jl`](https://cuda.juliagpu.org/stable/installation/overview/).
140 changes: 97 additions & 43 deletions src/data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,12 @@ attributes:
SST:_FillValue = -9999.f ;
}
The the netCDF mask is 0 for invalid (e.g. land for an ocean application) and 1 for pixels (e.g. ocean).
"""
function load_gridded_nc(fname::AbstractString,varname::AbstractString; minfrac = 0.05)
function load_gridded_nc(fname::AbstractString,varname::AbstractString, errvarname; minfrac = 0.05, obs_err_std = 1f0)

@info "loading file $fname"
ds = Dataset(fname);
lon = nomissing(ds["lon"][:])
lat = nomissing(ds["lat"][:])
Expand All @@ -52,35 +56,55 @@ function load_gridded_nc(fname::AbstractString,varname::AbstractString; minfrac
println("mask: sea points ",sum(mask)," land points ",sum(.!mask))
end

close(ds)
missingmask = isnan.(data)
sz = size(data)

println("$varname data shape: $(format_size(sz)) data range: $(extrema(data[isfinite.(data)]))")

data4d = reshape(data,(sz[1],sz[2],1,sz[3]))
return lon,lat,time,data4d,missingmask,mask

if !isnothing(errvarname)
error4d = reshape(nomissing(ds[errvarname][:,:,:],Inf),(sz[1],sz[2],1,sz[3]))
else
@info("no error field provided using $obs_err_std as error standard dev.")
error4d = fill(obs_err_std,(sz[1],sz[2],1,sz[3]))
end

close(ds)
return lon,lat,time,data4d,error4d,missingmask,mask
end

"""
the first variable is for isoutput is none specified
"""
function load_gridded_nc(data::AbstractVector{NamedTuple{(:filename, :varname, :obs_err_std),T}}) where {T}
return load_gridded_nc([(d...,isoutput = i==1) for (i,d) in enumerate(data)])
function fill_default(data::AbstractVector{<:NamedTuple})
# the first variable is the output if isoutput is not specified
[(isoutput = i==1,
errvarname = nothing,
obs_err_std = 1f0,
jitter_std = 0.05,
ndims = 1,
d..., # overwrite all default value if provided
) for (i,d) in enumerate(data)]
end

function load_gridded_nc(data)
lon,lat,datatime,data_full1,missingmask,mask = load_gridded_nc(data[1].filename,data[1].varname);
d = fill_default(data)

lon,lat,datatime,data_full1,error_full1,missingmask,mask = load_gridded_nc(
d[1].filename,d[1].varname,d[1].errvarname);

sz = size(data_full1)
data_full = zeros(Float32,sz[1],sz[2],length(data),sz[4]);
error_full = zeros(Float32,sz[1],sz[2],length(data),sz[4]);

data_full[:,:,1,:] = data_full1;
error_full[:,:,1,:] = error_full1;

for i in 2:length(data)
lon_,lat_,datatime_,data_full[:,:,i,:],missingmask_,mask_ = load_gridded_nc(
data[i].filename,data[i].varname);
lon_,lat_,datatime_,data_full[:,:,i,:],error_full[:,:,i,:],missingmask_,mask_ = load_gridded_nc(
d[i].filename,d[i].varname,d[i].errvarname,
obs_err_std = d[i].obs_err_std);
end

return lon,lat,datatime,data_full,missingmask,mask
return lon,lat,datatime,data_full,error_full,missingmask,mask
end

function normalize2(data)
Expand All @@ -98,13 +122,13 @@ function load_aux_data(T,sz,auxdata_files)

for i = 1:length(auxdata_files)
NCDataset(auxdata_files[i].filename) do ds
data = nomissing(ds[auxdata_files[i].varname][:,:,:])
data = nomissing(Array(ds[auxdata_files[i].varname]))

data_std_err =
if isnothing(auxdata_files[i].errvarname)
ones(size(data))
else
nomissing(ds[auxdata_files[i].errvarname][:,:,:])
nomissing(Array(ds[auxdata_files[i].errvarname]))
end

@info("remove time mean from $(auxdata_files[i].filename)")
Expand All @@ -127,10 +151,11 @@ mutable struct NCData{T,N #=,TA=#}
data_full::Array{T,4}
missingmask::BitArray{3}
meandata::Array{T,3}
mask::BitMatrix
x::Array{T,5}
isoutput::Vector{Bool}
train::Bool
obs_err_std::Vector{T}
# obs_err_std::Vector{T}
jitter_std::Vector{T}
lon_scaled::Vector{T}
lat_scaled::Vector{T}
Expand Down Expand Up @@ -208,8 +233,10 @@ export sizey
"""
dd = NCData(lon,lat,time,data_full,missingmask,ndims;
train = false,
obs_err_std = 1.,
jitter_std = 0.05)
obs_err_std = fill(1.,size(data_full,3)),
jitter_std = fill(0.05,size(data_full,3)),
mask = trues(size(data_full)[1:2]),
)
Return a structure holding the data for training (`train = true`) or testing (`train = false`)
the neural network. `obs_err_std` is the error standard deviation of the
Expand All @@ -229,6 +256,7 @@ function NCData(lon,lat,time,data_full,missingmask,ndims;
cycle_periods = (365.25,), # days
time_origin = DateTime(1970,1,1),
remove_mean = true,
mask = trues(size(data_full)[1:2]),
direction_obs = nothing,
# auxdata = (),
)
Expand All @@ -244,6 +272,10 @@ function NCData(lon,lat,time,data_full,missingmask,ndims;
ndata = size(data_full,3)
ntime = size(data_full,4)

@info "size(obs_err_std)" size(obs_err_std)
@info "number of parameters " ndata
@info "number of time instances " ntime

time_cos = zeros(Float32,length(cycle_periods),length(time))
time_sin = zeros(Float32,length(cycle_periods),length(time))

Expand Down Expand Up @@ -276,10 +308,15 @@ function NCData(lon,lat,time,data_full,missingmask,ndims;
x[:,:,:,:,1] = replace(data,NaN => 0)
x[:,:,:,:,2] = (1 .- isnan.(data))

for i = 1:ndata
# inv. error variance
x[:,:,i,:,1] ./= obs_err_std[i]^2
x[:,:,i,:,2] ./= obs_err_std[i]^2
if Base.ndims(obs_err_std) == 1
for i = 1:ndata
# inv. error variance
x[:,:,i,:,1] ./= obs_err_std[i]^2
x[:,:,i,:,2] ./= obs_err_std[i]^2
end
else
x[:,:,:,:,1] ./= obs_err_std.^2
x[:,:,:,:,2] ./= obs_err_std.^2
end
# else
# # dimensions of x: lon, lat, parameter, time, 5
Expand All @@ -300,39 +337,47 @@ function NCData(lon,lat,time,data_full,missingmask,ndims;

N = (is3D ? 4 : 3)

NCData{Float32,N}(Float32.(lon),Float32.(lat),time,data_full,missingmask,meandata[:,:,:,1],x,
isoutput,
train,
Float32.(obs_err_std),
Float32.(jitter_std),
lon_scaled,
lat_scaled,
time_cos,
time_sin,
ntime_win,
# auxdata,
direction_obs_,
output_ndims,
ndims,
)
NCData{Float32,N}(
Float32.(lon),Float32.(lat),
time,
data_full,
missingmask,
meandata[:,:,:,1],
mask,
x,
isoutput,
train,
# Float32.(obs_err_std),
Float32.(jitter_std),
lon_scaled,
lat_scaled,
time_cos,
time_sin,
ntime_win,
# auxdata,
direction_obs_,
output_ndims,
ndims,
)
end


getp(x,sym,default) = (hasproperty(x, sym) ? getproperty(x,sym) : default)

function NCData(data; kwargs...)
lon,lat,datatime,data_full,missingmask,mask = DINCAE.load_gridded_nc(data)
lon,lat,datatime,data_full,error_full,missingmask,mask = load_gridded_nc(data)

default_jitter_std = 0.05

jitter_std = [getp(d,:jitter_std,default_jitter_std) for d in data]
ndims = [getp(d,:ndims,1) for d in data]

return DINCAE.NCData(lon,lat,datatime,data_full,missingmask,ndims;
obs_err_std = [d.obs_err_std for d in data],
jitter_std = jitter_std,
isoutput = [d.isoutput for d in data],
kwargs...)
return NCData(lon,lat,datatime,data_full,missingmask,ndims;
obs_err_std = error_full,
jitter_std = jitter_std,
isoutput = [d.isoutput for d in data],
mask = mask,
kwargs...)

end

Expand Down Expand Up @@ -554,7 +599,10 @@ function getobs!(dd::NCData,data,index::Int)
return data
end

function savesample(ds,varnames,xrec,meandata,ii,offset; output_ndims = 1)
function savesample(ds,varnames,xrec,meandata,ii,offset;
output_ndims = 1,
mask = nothing)

fill_value = -9999.

function accumulate!(var,index,slice,count)
Expand Down Expand Up @@ -592,6 +640,7 @@ function savesample(ds,varnames,xrec,meandata,ii,offset; output_ndims = 1)
end
end

# typically the batch size
nmax = size(xrec,4)

if output_ndims == 1
Expand All @@ -608,6 +657,11 @@ function savesample(ds,varnames,xrec,meandata,ii,offset; output_ndims = 1)
batch_sigma_rec[isnan.(recdata)] .= NaN

for n in 1:nmax
if !isnothing(mask)
view(recdata,:,:,n)[.!mask] .= NaN
view(batch_sigma_rec,:,:,n)[.!mask] .= NaN
end

accumulate!(nc_batch_m_rec.var,n+offset,recdata[:,:,n],count)
accumulate!(nc_batch_sigma_rec.var,n+offset,batch_sigma_rec[:,:,n],count)
end
Expand Down
Loading

0 comments on commit cec71e4

Please sign in to comment.