diff --git a/lbfgsb.go b/lbfgsb.go index 53c14fe..2aee9ce 100644 --- a/lbfgsb.go +++ b/lbfgsb.go @@ -19,6 +19,7 @@ import ( "fmt" "math" "reflect" + "sync" "unsafe" ) @@ -292,14 +293,17 @@ func (lbfgsb *Lbfgsb) Minimize( upperBounds := makeCCopySlice_Float(lbfgsb.upperBounds, dim) // Set up callbacks for function, gradient, and logging - callbackData_c := unsafe.Pointer( - &callbackData{objective: objective}) + cId := registerCallback(objective) + defer unregisterCallback(cId) + callbackData_c := unsafe.Pointer(cId) var doLogging_c C.int // false var logFunctionCallbackData_c unsafe.Pointer // null + var loggerId uintptr if lbfgsb.logger != nil { doLogging_c = C.int(1) // true - logFunctionCallbackData_c = unsafe.Pointer( - &logCallbackData{logger: lbfgsb.logger}) + loggerId = registerCallback(lbfgsb.logger) + defer unregisterCallback(loggerId) + logFunctionCallbackData_c = unsafe.Pointer(loggerId) } // Allocate arrays for return value @@ -379,19 +383,57 @@ func (lbfgsb *Lbfgsb) OptimizationStatistics() OptimizationStatistics { return lbfgsb.statistics } -// callbackData is a container for the actual objective function and +// callbackFunctions is a container for the actual objective functions and // related data. -type callbackData struct { - objective FunctionWithGradient +var callbackFunctions = make(map[uintptr]interface{}) + +// callbackIndex stores an index to use for new callback function. +var callbackIndex uintptr + +// callbackMutex is a mutex preventing simultanious access to callback +// and callbackIds. +var callbackMutex sync.Mutex + +// registerCallback registers a new callback and returns its' index +// (>=1). +func registerCallback(f interface{}) uintptr { + callbackMutex.Lock() + defer callbackMutex.Unlock() + // We always increment callbackIndex to have more or less + // unique ids. This way it is easier to debug problems with + // reusing unregistered ids. + callbackIndex++ + startIndex := callbackIndex + for callbackIndex == 0 || callbackFunctions[callbackIndex] != nil { + // Find the first free non-zero index. + callbackIndex++ + // If the map is full, i.e. all non-zero uintptrs were + // used, we do not want to loop infinitely. We check + // if we already encountered the starting index. If + // so, we panic. In practice this is very unlikely to + // have this kind of problem since all the objects are + // unregistered at the end of the function call. + if callbackIndex == startIndex { + panic("no more space in the map to store a callback function") + } + } + callbackFunctions[callbackIndex] = f + return callbackIndex } -// logCallbackData is a container for the logging function. It might be -// tempting to just use a function pointer instead of this container, -// but passing a function pointer to void* in C possibly truncates the -// address because void* is for data pointers only and function pointers -// may be wider. -type logCallbackData struct { - logger OptimizationIterationLogger +// lookupCallback returns a callback function given an index. +func lookupCallback(i uintptr) interface{} { + callbackMutex.Lock() + defer callbackMutex.Unlock() + return callbackFunctions[i] +} + +// unregisterCallback unregisters a callback by removing it from the +// callbackFunctions map. +func unregisterCallback(i uintptr) { + callbackMutex.Lock() + defer callbackMutex.Unlock() + delete(callbackFunctions, i) } // go_objective_function_callback is an adapter between the C callback @@ -411,11 +453,11 @@ func go_objective_function_callback( // Convert inputs dim := int(dim_c) wrapCArrayAsGoSlice_Float64(point_c, dim, &point) - cbData := (*callbackData)(callbackData_c) + objective := lookupCallback(uintptr(callbackData_c)).(FunctionWithGradient) // Evaluate the objective function. Let panics propagate through // C/Fortran. - value := cbData.objective.EvaluateFunction(point) + value := objective.EvaluateFunction(point) // Convert outputs *value_c = C.double(value) @@ -442,11 +484,11 @@ func go_objective_gradient_callback( // Convert inputs dim := int(dim_c) wrapCArrayAsGoSlice_Float64(point_c, dim, &point) - cbData := (*callbackData)(callbackData_c) + objective := lookupCallback(uintptr(callbackData_c)).(FunctionWithGradient) // Evaluate the gradient of the objective function. Let panics // propagate through C/Fortran. - gradRet = cbData.objective.EvaluateGradient(point) + gradRet = objective.EvaluateGradient(point) // Convert outputs wrapCArrayAsGoSlice_Float64(gradient_c, dim, &gradient) @@ -478,11 +520,11 @@ func go_log_function_callback( wrapCArrayAsGoSlice_Float64(g_c, dim, &g) // Get the logging function from the callback data - cbData := (*logCallbackData)(logCallbackData_c) + logger := lookupCallback(uintptr(logCallbackData_c)).(OptimizationIterationLogger) // Call the logging function. Let panics propagate through // C/Fortran. - cbData.logger( + logger( &OptimizationIterationInformation{ Iteration: int(iteration_c), FEvals: int(fgEvals_c),