Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing afbarnard/go-lbfgsb#4 #7

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 62 additions & 20 deletions lbfgsb.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"fmt"
"math"
"reflect"
"sync"
"unsafe"
)

Expand Down Expand Up @@ -292,14 +293,17 @@ func (lbfgsb *Lbfgsb) Minimize(
upperBounds := makeCCopySlice_Float(lbfgsb.upperBounds, dim)

// Set up callbacks for function, gradient, and logging
callbackData_c := unsafe.Pointer(
&callbackData{objective: objective})
cId := registerCallback(objective)
defer unregisterCallback(cId)
callbackData_c := unsafe.Pointer(cId)
var doLogging_c C.int // false
var logFunctionCallbackData_c unsafe.Pointer // null
var loggerId uintptr
if lbfgsb.logger != nil {
doLogging_c = C.int(1) // true
logFunctionCallbackData_c = unsafe.Pointer(
&logCallbackData{logger: lbfgsb.logger})
loggerId = registerCallback(lbfgsb.logger)
defer unregisterCallback(loggerId)
logFunctionCallbackData_c = unsafe.Pointer(loggerId)
}

// Allocate arrays for return value
Expand Down Expand Up @@ -379,19 +383,57 @@ func (lbfgsb *Lbfgsb) OptimizationStatistics() OptimizationStatistics {
return lbfgsb.statistics
}

// callbackData is a container for the actual objective function and
// callbackFunctions is a container for the actual objective functions and
// related data.
type callbackData struct {
objective FunctionWithGradient
var callbackFunctions = make(map[uintptr]interface{})

// callbackIndex stores an index to use for new callback function.
var callbackIndex uintptr

// callbackMutex is a mutex preventing simultanious access to callback
// and callbackIds.
var callbackMutex sync.Mutex

// registerCallback registers a new callback and returns its' index
// (>=1).
func registerCallback(f interface{}) uintptr {
callbackMutex.Lock()
defer callbackMutex.Unlock()
// We always increment callbackIndex to have more or less
// unique ids. This way it is easier to debug problems with
// reusing unregistered ids.
callbackIndex++
startIndex := callbackIndex
for callbackIndex == 0 || callbackFunctions[callbackIndex] != nil {
// Find the first free non-zero index.
callbackIndex++
// If the map is full, i.e. all non-zero uintptrs were
// used, we do not want to loop infinitely. We check
// if we already encountered the starting index. If
// so, we panic. In practice this is very unlikely to
// have this kind of problem since all the objects are
// unregistered at the end of the function call.
if callbackIndex == startIndex {
panic("no more space in the map to store a callback function")
}
}
callbackFunctions[callbackIndex] = f
return callbackIndex
}

// logCallbackData is a container for the logging function. It might be
// tempting to just use a function pointer instead of this container,
// but passing a function pointer to void* in C possibly truncates the
// address because void* is for data pointers only and function pointers
// may be wider.
type logCallbackData struct {
logger OptimizationIterationLogger
// lookupCallback returns a callback function given an index.
func lookupCallback(i uintptr) interface{} {
callbackMutex.Lock()
defer callbackMutex.Unlock()
return callbackFunctions[i]
}

// unregisterCallback unregisters a callback by removing it from the
// callbackFunctions map.
func unregisterCallback(i uintptr) {
callbackMutex.Lock()
defer callbackMutex.Unlock()
delete(callbackFunctions, i)
}

// go_objective_function_callback is an adapter between the C callback
Expand All @@ -411,11 +453,11 @@ func go_objective_function_callback(
// Convert inputs
dim := int(dim_c)
wrapCArrayAsGoSlice_Float64(point_c, dim, &point)
cbData := (*callbackData)(callbackData_c)
objective := lookupCallback(uintptr(callbackData_c)).(FunctionWithGradient)

// Evaluate the objective function. Let panics propagate through
// C/Fortran.
value := cbData.objective.EvaluateFunction(point)
value := objective.EvaluateFunction(point)

// Convert outputs
*value_c = C.double(value)
Expand All @@ -442,11 +484,11 @@ func go_objective_gradient_callback(
// Convert inputs
dim := int(dim_c)
wrapCArrayAsGoSlice_Float64(point_c, dim, &point)
cbData := (*callbackData)(callbackData_c)
objective := lookupCallback(uintptr(callbackData_c)).(FunctionWithGradient)

// Evaluate the gradient of the objective function. Let panics
// propagate through C/Fortran.
gradRet = cbData.objective.EvaluateGradient(point)
gradRet = objective.EvaluateGradient(point)

// Convert outputs
wrapCArrayAsGoSlice_Float64(gradient_c, dim, &gradient)
Expand Down Expand Up @@ -478,11 +520,11 @@ func go_log_function_callback(
wrapCArrayAsGoSlice_Float64(g_c, dim, &g)

// Get the logging function from the callback data
cbData := (*logCallbackData)(logCallbackData_c)
logger := lookupCallback(uintptr(logCallbackData_c)).(OptimizationIterationLogger)

// Call the logging function. Let panics propagate through
// C/Fortran.
cbData.logger(
logger(
&OptimizationIterationInformation{
Iteration: int(iteration_c),
FEvals: int(fgEvals_c),
Expand Down