Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow upcasting #944

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions include/behaviortree_cpp/basic_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,21 @@ class TypeInfo
template <typename T>
static TypeInfo Create()
{
// store the base class typeid if specialized
if constexpr(is_shared_ptr<T>::value)
{
using Elem = typename T::element_type;
using Base = typename any_cast_base<Elem>::type;

if constexpr(!std::is_same_v<Base, void>)
{
static_assert(is_polymorphic_safe_v<Base>, "TypeInfo Base trait specialization "
"must be "
"polymorphic");
return TypeInfo{ typeid(std::shared_ptr<Base>),
GetAnyFromStringFunctor<std::shared_ptr<Base>>() };
}
}
return TypeInfo{ typeid(T), GetAnyFromStringFunctor<T>() };
}

Expand Down Expand Up @@ -452,7 +467,8 @@ template <typename T = AnyTypeAllowed>
}
else
{
out = { sname, PortInfo(direction, typeid(T), GetAnyFromStringFunctor<T>()) };
auto type_info = TypeInfo::Create<T>();
out = { sname, PortInfo(direction, type_info.type(), type_info.converter()) };
}
if(!description.empty())
{
Expand Down Expand Up @@ -501,7 +517,6 @@ BidirectionalPort(StringView name, StringView description = {})

namespace details
{

template <typename T = AnyTypeAllowed, typename DefaultT = T>
[[nodiscard]] inline std::pair<std::string, PortInfo>
PortWithDefault(PortDirection direction, StringView name, const DefaultT& default_value,
Expand Down
6 changes: 4 additions & 2 deletions include/behaviortree_cpp/blackboard.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

namespace BT
{

/// This type contains a pointer to Any, protected
/// with a locked mutex as long as the object is in scope
using AnyPtrLocked = LockedPtr<Any>;
Expand Down Expand Up @@ -257,8 +256,11 @@ inline void Blackboard::set(const std::string& key, const T& value)

std::type_index previous_type = entry.info.type();

// allow matching if any is of the same base
const auto current_type = TypeInfo::Create<T>().type();

// check type mismatch
if(previous_type != std::type_index(typeid(T)) && previous_type != new_value.type())
if(previous_type != current_type && previous_type != new_value.type())
{
bool mismatching = true;
if(std::is_constructible<StringView, T>::value)
Expand Down
153 changes: 152 additions & 1 deletion include/behaviortree_cpp/utils/safe_any.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,52 @@

namespace BT
{

static std::type_index UndefinedAnyType = typeid(nullptr);

template <typename T>
struct any_cast_base
{
using type = void; // Default: no base known, fallback to default any storage
};

// Trait to detect std::shared_ptr types.
template <typename T>
struct is_shared_ptr : std::false_type
{
};

template <typename U>
struct is_shared_ptr<std::shared_ptr<U>> : std::true_type
{
};

// Trait to detect if a type is complete
template <typename T, typename = void>
struct is_complete : std::false_type
{
};

template <typename T>
struct is_complete<T, decltype(void(sizeof(T)))> : std::true_type
{
};

// Trait to detect if a trait is complete and polymorphic
template <typename T, typename = void>
struct is_polymorphic_safe : std::false_type
{
};

// Specialization only enabled if T is complete
template <typename T>
struct is_polymorphic_safe<T, std::enable_if_t<is_complete<T>::value>>
: std::integral_constant<bool, std::is_polymorphic<T>::value>
{
};

template <typename T>
inline constexpr bool is_polymorphic_safe_v = is_polymorphic_safe<T>::value;

// Rational: since type erased numbers will always use at least 8 bytes
// it is faster to cast everything to either double, uint64_t or int64_t.
class Any
Expand Down Expand Up @@ -107,6 +150,26 @@ class Any
Any(const std::type_index& type) : _original_type(type)
{}

// default for shared pointers
template <typename T>
explicit Any(const std::shared_ptr<T>& value)
: _original_type(typeid(std::shared_ptr<T>))
{
using Base = typename any_cast_base<T>::type;

// store as base class if specialized
if constexpr(!std::is_same_v<Base, void>)
{
static_assert(is_polymorphic_safe_v<Base>, "Any Base trait specialization must be "
"polymorphic");
_any = std::static_pointer_cast<Base>(value);
}
else
{
_any = value;
}
}

// default for other custom types
template <typename T>
explicit Any(const T& value, EnableNonIntegral<T> = 0)
Expand Down Expand Up @@ -158,6 +221,9 @@ class Any
// Method to access the value by pointer.
// It will return nullptr, if the user try to cast it to a
// wrong type or if Any was empty.
//
// WARNING: The returned pointer may alias internal cache and be invalidated by subsequent castPtr() calls.
// Do not store it long-term. Applies only to shared_ptr<Derived> where Derived is polymorphic and base-registered.
template <typename T>
[[nodiscard]] T* castPtr()
{
Expand Down Expand Up @@ -189,6 +255,52 @@ class Any
"tea"
"d");

// Special case: applies only when requesting shared_ptr<Derived> and Derived is polymorphic
// with a registered base via any_cast_base.
if constexpr(is_shared_ptr<T>::value)
{
using Derived = typename T::element_type;
using Base = typename any_cast_base<Derived>::type;

if constexpr(is_polymorphic_safe_v<Derived> && !std::is_same_v<Base, void>)
{
try
{
// Attempt to retrieve the stored shared_ptr<Base> from the Any container
auto base_ptr = linb::any_cast<std::shared_ptr<Base>>(&_any);
if(!base_ptr)
return nullptr;

// Case 1: If Base and Derived are the same, no casting is needed
if constexpr(std::is_same_v<Base, Derived>)
{
return reinterpret_cast<T*>(base_ptr);
}

// Case 2: Originally stored as shared_ptr<Derived>
if(_original_type == typeid(std::shared_ptr<Derived>))
{
_cached_derived_ptr = std::static_pointer_cast<Derived>(*base_ptr);
return reinterpret_cast<T*>(&_cached_derived_ptr);
}

// Case 3: Fallback to dynamic cast
auto derived_ptr = std::dynamic_pointer_cast<Derived>(*base_ptr);
if(derived_ptr)
{
_cached_derived_ptr = derived_ptr;
return reinterpret_cast<T*>(&_cached_derived_ptr);
}
}
catch(...)
{
return nullptr;
}

return nullptr;
}
}

return _any.empty() ? nullptr : linb::any_cast<T>(&_any);
}

Expand All @@ -212,6 +324,7 @@ class Any
private:
linb::any _any;
std::type_index _original_type;
mutable std::shared_ptr<void> _cached_derived_ptr = nullptr;

//----------------------------

Expand Down Expand Up @@ -513,6 +626,44 @@ inline nonstd::expected<T, std::string> Any::tryCast() const
throw std::runtime_error("Any::cast failed because it is empty");
}

// special case: T is a shared_ptr to a registered polymorphic type.
// The stored value is a shared_ptr<Base>, but the user is requesting shared_ptr<Derived>.
// Perform safe downcasting (static or dynamic) from Base to Derived if applicable.
if constexpr(is_shared_ptr<T>::value)
{
using Derived = typename T::element_type;
using Base = typename any_cast_base<Derived>::type;

if constexpr(is_polymorphic_safe_v<Derived> && !std::is_same_v<Base, void>)
{
// Attempt to retrieve the stored shared_ptr<Base> from the Any container
auto base_ptr = linb::any_cast<std::shared_ptr<Base>>(_any);
if(!base_ptr)
{
throw std::runtime_error("Any::cast cannot cast to shared_ptr<Base> class");
}

// Case 1: If Base and Derived are the same, no casting is needed
if constexpr(std::is_same_v<T, std::shared_ptr<Base>>)
{
return base_ptr;
}

// Case 2: If the original stored type was shared_ptr<Derived>, we can safely static_cast
if(_original_type == typeid(std::shared_ptr<Derived>))
{
return std::static_pointer_cast<Derived>(base_ptr);
}

// Case 3: Otherwise, attempt a dynamic cast from Base to Derived
auto derived_ptr = std::dynamic_pointer_cast<Derived>(base_ptr);
if(!derived_ptr)
throw std::runtime_error("Any::cast Dynamic cast failed, types are not related");

return derived_ptr;
}
}

if(castedType() == typeid(T))
{
return linb::any_cast<T>(_any);
Expand Down
78 changes: 78 additions & 0 deletions tests/gtest_any.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <gtest/gtest.h>

#include <behaviortree_cpp/utils/safe_any.hpp>
#include "greeter_test.h"

using namespace BT;

Expand Down Expand Up @@ -249,4 +250,81 @@ TEST(Any, Cast)
Any a(v);
EXPECT_EQ(a.cast<std::vector<int>>(), v);
}

/// Issue 943
// Type casting: polymorphic class w/ registered base class
{
auto g = std::make_shared<Greeter>();
Any any_g(g);
EXPECT_NO_THROW(auto res = any_g.cast<Greeter::Ptr>());
EXPECT_ANY_THROW(auto res = any_g.cast<HelloGreeter::Ptr>());
EXPECT_ANY_THROW(auto res = any_g.cast<FancyHelloGreeter::Ptr>());
EXPECT_TRUE(any_g.castPtr<Greeter::Ptr>());
EXPECT_FALSE(any_g.castPtr<HelloGreeter::Ptr>());
EXPECT_FALSE(any_g.castPtr<FancyHelloGreeter::Ptr>());

auto hg = std::make_shared<HelloGreeter>();
Any any_hg(hg);
EXPECT_NO_THROW(auto res = any_hg.cast<Greeter::Ptr>());
EXPECT_NO_THROW(auto res = any_hg.cast<HelloGreeter::Ptr>());
EXPECT_ANY_THROW(auto res = any_hg.cast<FancyHelloGreeter::Ptr>());
EXPECT_TRUE(any_hg.castPtr<Greeter::Ptr>());
EXPECT_TRUE(any_hg.castPtr<HelloGreeter::Ptr>());
EXPECT_FALSE(any_hg.castPtr<FancyHelloGreeter::Ptr>());

auto fhg = std::make_shared<FancyHelloGreeter>();
Any any_fhg(fhg);
EXPECT_NO_THROW(auto res = any_fhg.cast<Greeter::Ptr>());
EXPECT_NO_THROW(auto res = any_fhg.cast<HelloGreeter::Ptr>());
EXPECT_NO_THROW(auto res = any_fhg.cast<FancyHelloGreeter::Ptr>());
EXPECT_TRUE(any_fhg.castPtr<Greeter::Ptr>());
EXPECT_TRUE(any_fhg.castPtr<HelloGreeter::Ptr>());
EXPECT_TRUE(any_fhg.castPtr<FancyHelloGreeter::Ptr>());

// Try to upcast to an incorrectly registered base
auto u = std::make_shared<Unwelcomer>();

// OK, fails to compile -> invalid static cast
// Any any_u(u);
// EXPECT_ANY_THROW(auto res = any_g.cast<Unwelcomer::Ptr>());
}

// Type casting: polymorphic class w/o registered base class
{
auto g = std::make_shared<GreeterNoReg>();
Any any_g(g);
EXPECT_NO_THROW(auto res = any_g.cast<GreeterNoReg::Ptr>());
EXPECT_ANY_THROW(auto res = any_g.cast<HelloGreeterNoReg::Ptr>());
EXPECT_TRUE(any_g.castPtr<GreeterNoReg::Ptr>());
EXPECT_FALSE(any_g.castPtr<HelloGreeterNoReg::Ptr>());

auto hg = std::make_shared<HelloGreeterNoReg>();
Any any_hg(hg);
EXPECT_ANY_THROW(auto res = any_hg.cast<GreeterNoReg::Ptr>());
EXPECT_NO_THROW(auto res = any_hg.cast<HelloGreeterNoReg::Ptr>());
EXPECT_FALSE(any_hg.castPtr<GreeterNoReg::Ptr>());
EXPECT_TRUE(any_hg.castPtr<HelloGreeterNoReg::Ptr>());
}

// Type casting: non polymorphic class w/ registered base class
{
// OK: static_assert(std::is_polymorphic_v<Base>, "Base must be polymorphic")
}

// Type casting: non polymorphic class w/o registered base class
{
auto g = std::make_shared<GreeterNoPolyReg>();
Any any_g(g);
EXPECT_NO_THROW(auto res = any_g.cast<GreeterNoPolyReg::Ptr>());
EXPECT_ANY_THROW(auto res = any_g.cast<HelloGreeterNoPolyReg::Ptr>());
EXPECT_TRUE(any_g.castPtr<GreeterNoPolyReg::Ptr>());
EXPECT_FALSE(any_g.castPtr<HelloGreeterNoPolyReg::Ptr>());

auto hg = std::make_shared<HelloGreeterNoPolyReg>();
Any any_hg(hg);
EXPECT_ANY_THROW(auto res = any_hg.cast<GreeterNoPolyReg::Ptr>());
EXPECT_NO_THROW(auto res = any_hg.cast<HelloGreeterNoPolyReg::Ptr>());
EXPECT_FALSE(any_hg.castPtr<GreeterNoPolyReg::Ptr>());
EXPECT_TRUE(any_hg.castPtr<HelloGreeterNoPolyReg::Ptr>());
}
}
Loading
Loading