Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 9 additions & 15 deletions src/RegisterWorkerShell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module RegisterWorkerShell
using SimpleTraits, ImageAxes, ImageMetadata, Distributed, SharedArrays
using AxisArrays: AxisArray, Axis

export AbstractWorker, AnyValue, ArrayDecl, close!, init!, maybe_sharedarray, monitor, monitor!, worker, workerpid, getindex_t
export AbstractWorker, AnyValue, ArrayDecl, close!, init!, maybe_sharedarray, monitor, monitor_thread, monitor!, worker, workerpid, getindex_t
export load_mm_package

"""
Expand Down Expand Up @@ -36,7 +36,7 @@ subtypes. The exported operations are:
- `init!` and `close!`: functions you may specialize if your algorithm
needs to initialize or clean up resources
- `worker`: perform registration on an image
- `workerpid`: extract the process-id for a given worker
- `workertid`: extract the thread-id
"""
RegisterWorkerShell

Expand All @@ -58,26 +58,20 @@ The worker algorithm should call `monitor!(mon, algorithm)` to copy
the values into `mon`, and `monitor!(mon, :var3, var3)` for an
internal variable `var3` that is not taken from `algorithm`. See
`monitor!` for more detail.

An important detail is that if `workerpid(algorithm) ≠ myid()`, then any
requested `AbstractArray` fields in `algorithm` will be turned into
`SharedArray`s for `mon`. This reduces the cost of communication
between the worker and driver processes.
"""
function monitor(algorithm::AbstractWorker, fields::Union{NTuple{N,Symbol},Vector{Symbol}}, morevars::Dict{Symbol} = Dict{Symbol,Any}()) where N
pid = workerpid(algorithm)
mon = Dict{Symbol,Any}()
for f in fields
isdefined(algorithm, f) || continue
mon[f] = maybe_sharedarray(getfield(algorithm, f), pid)
mon[f] = getfield(algorithm, f)
end
for (k,v) in morevars
mon[k] = maybe_sharedarray(v, pid)
mon[k] = v
end
mon
end

monitor(algorithm::Vector{W}, fields, morevars::Dict{Symbol} = Dict{Symbol,Any}()) where {W<:AbstractWorker} = map(alg->monitor(alg, fields, morevars), algorithm)
monitor(algorithms::Vector{W}, fields, morevars::Dict{Symbol} = Dict{Symbol,Any}()) where {W<:AbstractWorker} =
map(alg->monitor(alg, fields, morevars), algorithms) # for multi-thread

"""
`monitor!(mon, algorithm)` updates `mon` with the current values of
Expand Down Expand Up @@ -150,12 +144,12 @@ worker(algorithm::AbstractWorker, img, tindex, mon) = error("Worker modules must
worker(rr::RemoteChannel, img, tindex, mon) = worker(fetch(rr), img, tindex, mon)

"""
`workerpid(algorithm)` extracts the `pid` associated with the worker
`workertid(algorithm)` extracts the `workertid` associated with the thread
that will be assigned tasks for `algorithm`. All `AbstractWorker`
subtypes should include a `workerpid` field, or overload this function
subtypes should include a `workertid` field, or overload this function
to return myid().
"""
workerpid(w::AbstractWorker) = w.workerpid
workertid(w::AbstractWorker) = w.workertid

"""
`load_mm_package(dev)` loads appropriate mismatch module conditioned on
Expand Down
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
using RegisterWorkerShell, Test

@test 1+1 == 2