diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index efcf3c6..8c72409 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -98,6 +98,14 @@ def set_non_null_many_relationships(non_null_flag): use_non_null_many_relationships = non_null_flag +use_id_type_for_keys = True + + +def set_id_for_keys(id_flag): + global use_id_type_for_keys + use_id_type_for_keys = id_flag + + def get_column_doc(column): return getattr(column, "doc", None) @@ -309,18 +317,34 @@ def inner(fn): convert_sqlalchemy_composite.register = _register_composite_class +def _is_primary_or_foreign_key(column): + return getattr(column, "primary_key", False) or ( + len(getattr(column, "foreign_keys", [])) > 0 + ) + + def convert_sqlalchemy_column(column_prop, registry, resolver, **field_kwargs): column = column_prop.columns[0] - # The converter expects a type to find the right conversion function. - # If we get an instance instead, we need to convert it to a type. - # The conversion function will still be able to access the instance via the column argument. + # We only use the converter if no type was specified using the ORMField if "type_" not in field_kwargs: - column_type = getattr(column, "type", None) - if not isinstance(column_type, type): - column_type = type(column_type) + # If the column is a primary key, we use the ID typ + if use_id_type_for_keys and _is_primary_or_foreign_key(column): + field_type = graphene.ID + else: + # The converter expects a type to find the right conversion function. + # If we get an instance instead, we need to convert it to a type. + # The conversion function will still be able to access the instance via the column argument. + column_type = getattr(column, "type", None) + if not isinstance(column_type, type): + column_type = type(column_type) + + field_type = convert_sqlalchemy_type( + column_type, column=column, registry=registry + ) + field_kwargs.setdefault( "type_", - convert_sqlalchemy_type(column_type, column=column, registry=registry), + field_type, ) field_kwargs.setdefault("required", not is_column_nullable(column)) field_kwargs.setdefault("description", get_column_doc(column)) @@ -444,10 +468,6 @@ def convert_column_to_int_or_id( registry: Registry = None, **kwargs, ): - # fixme drop the primary key processing from here in another pr - if column is not None: - if getattr(column, "primary_key", False) is True: - return graphene.ID return graphene.Int diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index e62e07d..1c741a4 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -23,7 +23,7 @@ convert_sqlalchemy_hybrid_method, convert_sqlalchemy_relationship, convert_sqlalchemy_type, - set_non_null_many_relationships, + set_non_null_many_relationships, set_id_for_keys, ) from ..fields import UnsortedSQLAlchemyConnectionField, default_connection_field_factory from ..registry import Registry, get_global_registry @@ -46,11 +46,13 @@ def mock_resolver(): pass -def get_field(sqlalchemy_type, **column_kwargs): +def get_field(sqlalchemy_type, *column_args, **column_kwargs): class Model(declarative_base()): __tablename__ = "model" id_ = Column(types.Integer, primary_key=True) - column = Column(sqlalchemy_type, doc="Custom Help Text", **column_kwargs) + column = Column( + sqlalchemy_type, *column_args, doc="Custom Help Text", **column_kwargs + ) column_prop = inspect(Model).column_attrs["column"] return convert_sqlalchemy_column(column_prop, get_global_registry(), mock_resolver) @@ -396,12 +398,34 @@ def test_should_integer_convert_int(): assert get_field(types.Integer()).type == graphene.Int -def test_should_primary_integer_convert_id(): +def test_should_key_integer_convert_id(): assert get_field(types.Integer(), primary_key=True).type == graphene.NonNull( graphene.ID ) +def test_should_key_integer_convert_integer_with_setting(): + set_id_for_keys(False) + assert get_field(types.Integer(), primary_key=True).type == graphene.NonNull( + graphene.Int + ) + +def test_should_primary_string_convert_id(): + assert get_field(types.String(), primary_key=True).type == graphene.NonNull( + graphene.ID + ) + + +def test_should_primary_uuid_convert_id(): + assert get_field(sqa_utils.UUIDType, primary_key=True).type == graphene.NonNull( + graphene.ID + ) + + +def test_should_foreign_key_convert_id(): + assert get_field(types.Integer(), ForeignKey("model.id_")).type == graphene.ID + + def test_should_boolean_convert_boolean(): assert get_field(types.Boolean()).type == graphene.Boolean