diff --git a/libs/code-actions/src/code_actions_provider.rs b/libs/code-actions/src/code_actions_provider.rs index 4dd138e..293e4fc 100644 --- a/libs/code-actions/src/code_actions_provider.rs +++ b/libs/code-actions/src/code_actions_provider.rs @@ -74,3 +74,23 @@ impl CodeActionsProvider { refactors } } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_set_base_path() { + let provider = CodeActionsProvider::new(); + provider.set_base_path("test".to_string()); + assert_eq!(*provider.base_path.read().unwrap(), "test"); + } + + #[test] + fn test_update_file_content() { + let provider = CodeActionsProvider::new(); + provider.set_base_path("tests".to_string()); + let result = provider.update_file_content(); + assert!(!result.is_ok()); + } +} diff --git a/libs/code-actions/src/completions/auto_complete_provider.rs b/libs/code-actions/src/completions/auto_complete_provider.rs index 9d59309..eb2dbc3 100644 --- a/libs/code-actions/src/completions/auto_complete_provider.rs +++ b/libs/code-actions/src/completions/auto_complete_provider.rs @@ -104,3 +104,26 @@ impl AutoCompleteProvider { vec![] } } + +#[cfg(test)] +mod test { + use crate::test_utils::{create_test_ast_file, create_test_ast_file_struct_definition}; + + use super::*; + + #[test] + fn test_get_suggestions() { + let provider = AutoCompleteProvider::new(); + let uri = "test.sol"; + let position = Position { + line: 10, + column: 15, + }; + let files = vec![ + create_test_ast_file(), + create_test_ast_file_struct_definition(), + ]; + let completes = provider.get_suggestions(uri, position, &files); + assert_eq!(completes.len(), 4); + } +} diff --git a/libs/code-actions/src/lib.rs b/libs/code-actions/src/lib.rs index 1172f3b..294d964 100644 --- a/libs/code-actions/src/lib.rs +++ b/libs/code-actions/src/lib.rs @@ -2,6 +2,7 @@ mod code_actions_provider; mod completions; mod error; mod references; +mod test_utils; mod types; mod utils; diff --git a/libs/code-actions/src/references/definition_visitor.rs b/libs/code-actions/src/references/definition_visitor.rs index f6cec4d..92faafa 100644 --- a/libs/code-actions/src/references/definition_visitor.rs +++ b/libs/code-actions/src/references/definition_visitor.rs @@ -77,3 +77,199 @@ impl DefinitionVisitor { self.node.clone() } } + +#[cfg(test)] +mod test { + use crate::test_utils::{ + create_test_ast_file_contract_definition, create_test_ast_file_enum_definition, + create_test_ast_file_enum_value, create_test_ast_file_event_definition, + create_test_ast_file_function_definition, create_test_ast_file_modifier_definition, + create_test_ast_file_struct_definition, create_test_ast_file_variable_declaration, + }; + + use super::*; + + #[test] + fn test_find_contract_definition() { + let id = 1; + let file = create_test_ast_file_contract_definition(); + let mut visitor = DefinitionVisitor::new(id); + let node = visitor.find(&file.ast); + assert!(node.is_some()); + if let Some(InteractableNode::ContractDefinition(contract)) = node { + assert_eq!(contract.id, id); + } else { + panic!("Expected ContractDefinition node"); + } + } + + #[test] + fn test_find_contract_definition_not_found() { + let id = 0; + let file = create_test_ast_file_contract_definition(); + let mut visitor = DefinitionVisitor::new(id); + let node = visitor.find(&file.ast); + assert!(node.is_none()); + } + + #[test] + fn test_find_function_definition() { + let id = 2; + let file = create_test_ast_file_function_definition(); + let mut visitor = DefinitionVisitor::new(id); + let node = visitor.find(&file.ast); + assert!(node.is_some()); + if let Some(InteractableNode::FunctionDefinition(function)) = node { + assert_eq!(function.id, id); + } else { + panic!("Expected FunctionDefinition node"); + } + } + + #[test] + fn test_find_function_definition_not_found() { + let id = 0; + let file = create_test_ast_file_function_definition(); + let mut visitor = DefinitionVisitor::new(id); + let node = visitor.find(&file.ast); + assert!(node.is_none()); + } + + #[test] + fn test_find_modifier_definition() { + let id = 4; + let file = create_test_ast_file_modifier_definition(); + let mut visitor = DefinitionVisitor::new(id); + let node = visitor.find(&file.ast); + assert!(node.is_some()); + if let Some(InteractableNode::ModifierDefinition(modifier)) = node { + assert_eq!(modifier.id, id); + } else { + panic!("Expected ModifierDefinition node"); + } + } + + #[test] + fn test_find_modifier_definition_not_found() { + let id = 0; + let file = create_test_ast_file_modifier_definition(); + let mut visitor = DefinitionVisitor::new(id); + let node = visitor.find(&file.ast); + assert!(node.is_none()); + } + + #[test] + fn test_find_struct_definition() { + let id = 5; + let file = create_test_ast_file_struct_definition(); + let mut visitor = DefinitionVisitor::new(id); + let node = visitor.find(&file.ast); + assert!(node.is_some()); + if let Some(InteractableNode::StructDefinition(struct_def)) = node { + assert_eq!(struct_def.id, id); + } else { + panic!("Expected StructDefinition node"); + } + } + + #[test] + fn test_find_struct_definition_not_found() { + let id = 0; + let file = create_test_ast_file_struct_definition(); + let mut visitor = DefinitionVisitor::new(id); + let node = visitor.find(&file.ast); + assert!(node.is_none()); + } + + #[test] + fn test_find_enum_definition() { + let id = 6; + let file = create_test_ast_file_enum_definition(); + let mut visitor = DefinitionVisitor::new(id); + let node = visitor.find(&file.ast); + assert!(node.is_some()); + if let Some(InteractableNode::EnumDefinition(enum_def)) = node { + assert_eq!(enum_def.id, id); + } else { + panic!("Expected EnumDefinition node"); + } + } + + #[test] + fn test_find_enum_definition_not_found() { + let id = 0; + let file = create_test_ast_file_enum_definition(); + let mut visitor = DefinitionVisitor::new(id); + let node = visitor.find(&file.ast); + assert!(node.is_none()); + } + + #[test] + fn test_find_variable_declaration() { + let id = 3; + let file = create_test_ast_file_variable_declaration(); + let mut visitor = DefinitionVisitor::new(id); + let node = visitor.find(&file.ast); + assert!(node.is_some()); + if let Some(InteractableNode::VariableDeclaration(variable)) = node { + assert_eq!(variable.id, id); + } else { + panic!("Expected VariableDeclaration node"); + } + } + + #[test] + fn test_find_variable_declaration_not_found() { + let id = 0; + let file = create_test_ast_file_variable_declaration(); + let mut visitor = DefinitionVisitor::new(id); + let node = visitor.find(&file.ast); + assert!(node.is_none()); + } + + #[test] + fn test_find_enum_value() { + let id = 8; + let file = create_test_ast_file_enum_value(); + let mut visitor = DefinitionVisitor::new(id); + let node = visitor.find(&file.ast); + assert!(node.is_some()); + if let Some(InteractableNode::EnumValue(enum_def)) = node { + assert_eq!(enum_def.id, id); + } else { + panic!("Expected EnumDefinition node"); + } + } + + #[test] + fn test_find_enum_value_not_found() { + let id = 0; + let file = create_test_ast_file_enum_value(); + let mut visitor = DefinitionVisitor::new(id); + let node = visitor.find(&file.ast); + assert!(node.is_none()); + } + + #[test] + fn test_find_event_definition() { + let id = 7; + let file = create_test_ast_file_event_definition(); + let mut visitor = DefinitionVisitor::new(id); + let node = visitor.find(&file.ast); + assert!(node.is_some()); + if let Some(InteractableNode::EventDefinition(event)) = node { + assert_eq!(event.id, id); + } else { + panic!("Expected EventDefinition node"); + } + } + + #[test] + fn test_find_event_definition_not_found() { + let id = 0; + let file = create_test_ast_file_event_definition(); + let mut visitor = DefinitionVisitor::new(id); + let node = visitor.find(&file.ast); + assert!(node.is_none()); + } +} diff --git a/libs/code-actions/src/references/position_node_visitor.rs b/libs/code-actions/src/references/position_node_visitor.rs index 8df35ce..fdeb560 100644 --- a/libs/code-actions/src/references/position_node_visitor.rs +++ b/libs/code-actions/src/references/position_node_visitor.rs @@ -60,7 +60,7 @@ impl<'ast> Visit<'ast> for PositionNodeVisitor { fn visit_enum_definition(&mut self, enum_def: &'ast EnumDefinition) { if is_node_in_range(&enum_def.src, &self.position, &self.source) { - self.above_node.clone_from(&self.node); + self.above_node = self.node.clone(); self.node = Some(InteractableNode::EnumDefinition(enum_def.clone())); } visit::visit_enum_definition(self, enum_def); @@ -183,3 +183,469 @@ impl PositionNodeVisitor { self.node.clone() } } + +#[cfg(test)] +mod test { + use super::*; + use crate::test_utils::{ + create_test_ast_file_contract_definition, create_test_ast_file_enum_definition, + create_test_ast_file_enum_value, create_test_ast_file_error_definition, + create_test_ast_file_event_definition, create_test_ast_file_function_call, + create_test_ast_file_function_definition, create_test_ast_file_identifier, + create_test_ast_file_inheritance_specifier, create_test_ast_file_member_access, + create_test_ast_file_modifier_definition, create_test_ast_file_modifier_invocation, + create_test_ast_file_new_expression, create_test_ast_file_struct_definition, + create_test_ast_file_user_defined_type_name, create_test_ast_file_using_for_directive, + create_test_ast_file_variable_declaration, + }; + + #[test] + fn test_find_contract_definition() { + let file = create_test_ast_file_contract_definition(); + let position = Position { + line: 3, + column: 10, + }; + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_some()); + if let Some(InteractableNode::ContractDefinition(contract)) = node { + assert_eq!(contract.name, "Test"); + } else { + panic!("Expected ContractDefinition, got {:?}", node); + } + } + + #[test] + fn test_find_contract_definition_wrong_position() { + let file = create_test_ast_file_contract_definition(); + let position = Position { line: 1, column: 1 }; + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_none()); + } + + #[test] + fn test_find_function_definition() { + let file = create_test_ast_file_function_definition(); + let position = Position { + line: 3, + column: 10, + }; + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_some()); + if let Some(InteractableNode::FunctionDefinition(function)) = node { + assert_eq!(function.name, "notUsed"); + } else { + panic!("Expected FunctionDefinition, got {:?}", node); + } + } + + #[test] + fn test_find_function_definition_wrong_position() { + let file = create_test_ast_file_function_definition(); + let position = Position { line: 2, column: 1 }; + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_none()); + } + + #[test] + fn test_find_variable_declaration() { + let file = create_test_ast_file_variable_declaration(); + let position = Position { + line: 3, + column: 21, + }; + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_some()); + if let Some(InteractableNode::VariableDeclaration(variable)) = node { + assert_eq!(variable.name, "number"); + } else { + panic!("Expected VariableDeclaration, got {:?}", node); + } + } + + #[test] + fn test_find_variable_declaration_wrong_position() { + let file = create_test_ast_file_variable_declaration(); + let position = Position { line: 2, column: 1 }; + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_none()); + } + + #[test] + fn test_find_enum_definition() { + let file = create_test_ast_file_enum_definition(); + let position = Position { + line: 3, + column: 12, + }; + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_some()); + if let Some(InteractableNode::EnumDefinition(enum_def)) = node { + assert_eq!(enum_def.name, "TestEnum"); + } else { + panic!("Expected EnumDefinition, got {:?}", node); + } + } + + #[test] + fn test_find_enum_definition_wrong_position() { + let file = create_test_ast_file_enum_definition(); + let position = Position { line: 2, column: 1 }; + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_none()); + } + + #[test] + fn test_find_enum_value() { + let file = create_test_ast_file_enum_value(); + let position = Position { + line: 4, + column: 14, + }; + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_some()); + if let Some(InteractableNode::EnumValue(enum_value)) = node { + assert_eq!(enum_value.name, "TestEnumValue"); + } else { + panic!("Expected EnumValue, got {:?}", node); + } + } + + #[test] + fn test_find_enum_value_wrong_position() { + let file = create_test_ast_file_enum_value(); + let position = Position { line: 2, column: 1 }; + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_none()); + } + + #[test] + fn test_find_struct_definition() { + let file = create_test_ast_file_struct_definition(); + let position = Position { + line: 4, + column: 14, + }; + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_some()); + if let Some(InteractableNode::StructDefinition(struct_def)) = node { + assert_eq!(struct_def.name, "TestStruct"); + } else { + panic!("Expected StructDefinition, got {:?}", node); + } + } + + #[test] + fn test_find_struct_definition_wrong_position() { + let file = create_test_ast_file_struct_definition(); + let position = Position { line: 2, column: 1 }; + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_none()); + } + + #[test] + fn test_find_event_definition() { + let file = create_test_ast_file_event_definition(); + let position = Position { + line: 4, + column: 18, + }; + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_some()); + if let Some(InteractableNode::EventDefinition(event)) = node { + assert_eq!(event.name, "TestEvent"); + } else { + panic!("Expected EventDefinition, got {:?}", node); + } + } + + #[test] + fn test_find_event_definition_wrong_position() { + let file = create_test_ast_file_event_definition(); + let position = Position { line: 2, column: 1 }; + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_none()); + } + + #[test] + fn test_find_using_for_directive() { + let file = create_test_ast_file_using_for_directive(); + let position = Position { + line: 4, + column: 19, + }; + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_some()); + if let Some(InteractableNode::UsingForDirective(_)) = node { + assert!(true) + } else { + panic!("Expected UsingForDirective, got {:?}", node); + } + } + + #[test] + fn test_find_using_for_directive_wrong_position() { + let file = create_test_ast_file_using_for_directive(); + let position = Position { line: 2, column: 1 }; + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_none()); + } + + #[test] + fn test_find_error_defintion() { + let file = create_test_ast_file_error_definition(); + let position = Position { + line: 3, + column: 15, + }; + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_some()); + if let Some(InteractableNode::ErrorDefinition(error)) = node { + assert_eq!(error.name, "TestError"); + } else { + panic!("Expected ErrorDefinition, got {:?}", node); + } + } + + #[test] + fn test_find_error_defintion_wrong_position() { + let file = create_test_ast_file_error_definition(); + let position = Position { line: 2, column: 1 }; + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_none()); + } + + #[test] + fn test_find_function_call() { + let file = create_test_ast_file_function_call(); + let position = Position { + line: 5, + column: 15, + }; + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_some()); + if let Some(InteractableNode::FunctionCall(_)) = node { + assert!(true); + } else { + panic!("Expected FunctionCall, got {:?}", node); + } + } + + #[test] + fn test_find_function_call_wrong_position() { + let file = create_test_ast_file_function_call(); + let position = Position { line: 2, column: 1 }; + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_none()); + } + + #[test] + fn test_find_modifier_definition() { + let file = create_test_ast_file_modifier_definition(); + let position = Position { + line: 4, + column: 18, + }; + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_some()); + if let Some(InteractableNode::ModifierDefinition(modifier)) = node { + assert_eq!(modifier.name, "modifier"); + } else { + panic!("Expected ModifierDefinition, got {:?}", node); + } + } + + #[test] + fn test_find_modifier_definition_wrong_position() { + let file = create_test_ast_file_modifier_definition(); + let position = Position { line: 2, column: 1 }; + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_none()); + } + + #[test] + fn test_find_modifier_invocation() { + let file = create_test_ast_file_modifier_invocation(); + let position = Position { + line: 4, + column: 29, + }; + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_some()); + if let Some(InteractableNode::ModifierInvocation(_)) = node { + assert!(true); + } else { + panic!("Expected ModifierInvocation, got {:?}", node); + } + } + + #[test] + fn test_find_modifier_invocation_wrong_position() { + let file = create_test_ast_file_modifier_invocation(); + let position = Position { line: 2, column: 1 }; + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_none()); + } + + #[test] + fn test_find_inheritance_specifier() { + let file = create_test_ast_file_inheritance_specifier(); + let position = Position { + line: 3, + column: 25, + }; + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_some()); + if let Some(InteractableNode::InheritanceSpecifier(_)) = node { + assert!(true); + } else { + panic!("Expected InheritanceSpecifier, got {:?}", node); + } + } + + #[test] + fn test_find_inheritance_specifier_wrong_position() { + let file = create_test_ast_file_inheritance_specifier(); + let position = Position { line: 2, column: 1 }; + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_none()); + } + + #[test] + fn test_find_identifier() { + let file = create_test_ast_file_identifier(); + let position = Position { + line: 6, + column: 15, + }; + + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_some()); + if let Some(InteractableNode::Identifier(identifier)) = node { + assert_eq!(identifier.name, "number"); + } else { + panic!("Expected Identifier, got {:?}", node); + } + } + + #[test] + fn test_find_indentifier_wrong_position() { + let file = create_test_ast_file_identifier(); + let position = Position { line: 2, column: 1 }; + + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_none()); + } + + #[test] + fn test_find_member_access() { + let file = create_test_ast_file_member_access(); + let position = Position { + line: 6, + column: 22, + }; + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_some()); + if let Some(InteractableNode::MemberAccess(member_access)) = node { + assert_eq!(member_access.member_name, "member"); + } else { + panic!("Expected MemberAccess, got {:?}", node); + } + } + + #[test] + fn test_find_member_access_wrong_position() { + let file = create_test_ast_file_member_access(); + let position = Position { line: 2, column: 1 }; + + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_none()); + } + + #[test] + fn test_find_new_expression() { + let file = create_test_ast_file_new_expression(); + let position = Position { + line: 6, + column: 20, + }; + + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_some()); + if let Some(InteractableNode::NewExpression(new_expression, _)) = node { + assert_eq!( + new_expression.node_type, + NewExpressionNodeType::NewExpression + ); + } else { + panic!("Expected NewExpression, got {:?}", node); + } + } + + #[test] + fn test_find_new_expression_wrong_position() { + let file = create_test_ast_file_new_expression(); + let position = Position { line: 2, column: 1 }; + + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_none()); + } + + #[test] + fn test_find_user_defined_type_name() { + let file = create_test_ast_file_user_defined_type_name(); + let position = Position { + line: 6, + column: 45, + }; + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_some()); + if let Some(InteractableNode::UserDefinedTypeName(udt)) = node { + assert_eq!(udt.name, Some("TestStruct".to_string())); + } else { + panic!("Expected UserDefinedTypeName, got {:?}", node); + } + } + + #[test] + fn test_find_user_defined_type_name_wrong_position() { + let file = create_test_ast_file_user_defined_type_name(); + let position = Position { line: 2, column: 1 }; + let mut visitor = PositionNodeVisitor::new(position, &file.file.content); + let node = visitor.find(&file.ast); + assert!(node.is_none()); + } +} diff --git a/libs/code-actions/src/references/reference_provider.rs b/libs/code-actions/src/references/reference_provider.rs index 9980156..3b08cda 100644 --- a/libs/code-actions/src/references/reference_provider.rs +++ b/libs/code-actions/src/references/reference_provider.rs @@ -106,3 +106,35 @@ impl ReferenceProvider { references } } + +#[cfg(test)] +mod test { + use super::*; + use crate::test_utils::create_test_ast_file_contract_definition; + + #[test] + fn test_get_definition() { + let provider = ReferenceProvider::new(); + let uri = "test.sol"; + let position = Position { + line: 3, + column: 10, + }; + let files = vec![create_test_ast_file_contract_definition()]; + let base_path = "/home/user/project"; + let location = provider.get_definition(uri, position, &files, base_path); + assert!(location.is_some()); + let location = location.unwrap(); + assert_eq!(location.uri, uri); + } + + #[test] + fn test_get_references() { + let provider = ReferenceProvider::new(); + let uri = "test.sol"; + let position = Position { line: 0, column: 0 }; + let files = vec![create_test_ast_file_contract_definition()]; + let references = provider.get_references(uri, position, &files); + assert_eq!(references.len(), 0); + } +} diff --git a/libs/code-actions/src/references/usage_visitor.rs b/libs/code-actions/src/references/usage_visitor.rs index 3e0756f..56c8302 100644 --- a/libs/code-actions/src/references/usage_visitor.rs +++ b/libs/code-actions/src/references/usage_visitor.rs @@ -103,3 +103,82 @@ impl UsageVisitor { self.to_find.clone() } } + +#[cfg(test)] +mod test { + use super::*; + use crate::test_utils::{ + create_test_ast_file_identifier, create_test_ast_file_identifier_path, + create_test_ast_file_import_directive, create_test_ast_file_member_access, + create_test_ast_file_user_defined_type_name, + }; + use crate::types::InteractableNode; + + #[test] + fn test_find_usages_identifier() { + let file = create_test_ast_file_identifier(); + let id = 30; + let mut visitor = UsageVisitor::new(id); + let usages = visitor.find(&file.ast); + assert_eq!(usages.len(), 2); + if let InteractableNode::Identifier(identifier) = &usages[0] { + assert_eq!(identifier.referenced_declaration, Some(id)); + assert_eq!(identifier.name, "number"); + } else { + panic!("Expected IdentifierPath, got: {:?}", usages[0]); + } + } + + #[test] + fn test_find_usages_identifier_path() { + let file = create_test_ast_file_identifier_path(); + let id = 15; + let mut visitor = UsageVisitor::new(id); + let usages = visitor.find(&file.ast); + assert_eq!(usages.len(), 1); + if let InteractableNode::IdentifierPath(identifier_path) = &usages[0] { + assert_eq!(identifier_path.referenced_declaration, id); + assert_eq!(identifier_path.name, "IdPath"); + } else { + panic!("Expected IdentifierPath, got: {:?}", usages[0]); + } + } + + #[test] + fn test_find_usages_with_imports() { + let file = create_test_ast_file_import_directive(); + let id = -1; + let mut visitor = UsageVisitor::new(id); + let usages = visitor.find(&file.ast); + assert_eq!(usages.len(), 0); + } + + #[test] + fn test_find_usages_user_defined_type_name() { + let file = create_test_ast_file_user_defined_type_name(); + let id = 5; + let mut visitor = UsageVisitor::new(id); + let usages = visitor.find(&file.ast); + assert_eq!(usages.len(), 1); + if let InteractableNode::UserDefinedTypeName(udt) = &usages[0] { + assert_eq!(udt.referenced_declaration, id); + } else { + panic!("Expected UserDefinedTypeName, got: {:?}", usages[0]); + } + } + + #[test] + fn test_find_usages_member_access() { + let file = create_test_ast_file_member_access(); + let id = 3; + let mut visitor = UsageVisitor::new(id); + let usages = visitor.find(&file.ast); + eprintln!("{:?}", usages); + assert_eq!(usages.len(), 1); + if let InteractableNode::MemberAccess(member) = &usages[0] { + assert_eq!(member.referenced_declaration, Some(id)); + } else { + panic!("Expected MemberAccess, got: {:?}", usages[0]); + } + } +} diff --git a/libs/code-actions/src/test_utils.rs b/libs/code-actions/src/test_utils.rs new file mode 100644 index 0000000..f225d9b --- /dev/null +++ b/libs/code-actions/src/test_utils.rs @@ -0,0 +1,1319 @@ +use std::{collections::HashMap, str::FromStr}; + +use osmium_libs_solidity_ast_extractor::{kw::create, types::SolidityAstFile}; +use solc_ast_rs_types::types::*; + +#[allow(dead_code)] +pub fn create_test_contract_definition() -> ContractDefinition { + ContractDefinition { + id: 1, + name: "Test".to_string(), + src: SourceLocation::from_str("29:215:1").unwrap(), + name_location: None, + abstract_: false, + base_contracts: vec![], + canonical_name: None, + contract_dependencies: vec![], + contract_kind: ContractDefinitionContractKind::Contract, + documentation: None, + fully_implemented: true, + internal_function_i_ds: HashMap::new(), + linearized_base_contracts: vec![], + node_type: ContractDefinitionNodeType::ContractDefinition, + nodes: vec![], + scope: 0, + used_errors: vec![], + used_events: vec![], + } +} + +#[allow(dead_code)] +pub fn create_test_function_definition() -> FunctionDefinition { + FunctionDefinition { + id: 2, + name: "notUsed".to_string(), + src: SourceLocation::from_str("152:86:1").unwrap(), + name_location: None, + visibility: Visibility::Public, + state_mutability: StateMutability::Nonpayable, + parameters: ParameterList { + id: 1, + node_type: ParameterListNodeType::ParameterList, + parameters: vec![], + src: SourceLocation::from_str("1:1:1").unwrap(), + }, + return_parameters: ParameterList { + id: 1, + node_type: ParameterListNodeType::ParameterList, + parameters: vec![], + src: SourceLocation::from_str("4:1:1").unwrap(), + }, + body: None, + kind: FunctionDefinitionKind::Function, + node_type: FunctionDefinitionNodeType::FunctionDefinition, + scope: 1, + implemented: true, + base_functions: None, + documentation: None, + function_selector: None, + modifiers: vec![], + overrides: None, + virtual_: false, + } +} + +#[allow(dead_code)] +pub fn create_test_variable_declaration() -> VariableDeclaration { + VariableDeclaration { + id: 3, + name: "number".to_string(), + src: SourceLocation::from_str("56:21:1").unwrap(), + visibility: Visibility::Public, + constant: false, + indexed: None, + base_functions: None, + documentation: None, + function_selector: None, + mutability: Mutability::Mutable, + name_location: None, + node_type: VariableDeclarationNodeType::VariableDeclaration, + overrides: Some(OverrideSpecifier { + src: SourceLocation::from_str("7:1:1").unwrap(), + id: 1, + node_type: OverrideSpecifierNodeType::OverrideSpecifier, + overrides: OverrideSpecifierOverrides::UserDefinedTypeNames(vec![ + UserDefinedTypeName { + id: 1, + node_type: UserDefinedTypeNameNodeType::UserDefinedTypeName, + referenced_declaration: 3, + src: SourceLocation::from_str("10:1:1").unwrap(), + type_descriptions: TypeDescriptions { + type_string: Some("uint256".to_string()), + type_identifier: None, + }, + contract_scope: (), + name: Some("Test".to_string()), + path_node: None, + }, + ]), + }), + scope: 2, + state_variable: false, + storage_location: StorageLocation::Default, + type_descriptions: TypeDescriptions { + type_string: Some("uint256".to_string()), + type_identifier: None, + }, + type_name: None, + value: Some(Expression::Identifier(Identifier { + id: 1, + node_type: IdentifierNodeType::Identifier, + name: "number".to_string(), + src: SourceLocation::from_str("13:1:30").unwrap(), + referenced_declaration: Some(30), + type_descriptions: TypeDescriptions { + type_string: Some("uint256".to_string()), + type_identifier: None, + }, + overloaded_declarations: vec![], + argument_types: None, + })), + } +} + +#[allow(dead_code)] +pub fn create_test_modifier_definition() -> ModifierDefinition { + ModifierDefinition { + id: 4, + name: "modifier".to_string(), + src: SourceLocation::from_str("16:1:1").unwrap(), + name_location: None, + visibility: Visibility::Public, + documentation: None, + node_type: ModifierDefinitionNodeType::ModifierDefinition, + parameters: ParameterList { + id: 1, + node_type: ParameterListNodeType::ParameterList, + parameters: vec![], + src: SourceLocation::from_str("19:1:1").unwrap(), + }, + virtual_: false, + base_modifiers: None, + body: Block { + id: 1, + src: SourceLocation::from_str("22:1:1").unwrap(), + statements: None, + node_type: BlockNodeType::Block, + documentation: None, + }, + overrides: None, + } +} + +#[allow(dead_code)] +pub fn create_test_struct_definition() -> StructDefinition { + StructDefinition { + id: 5, + name: "TestStruct".to_string(), + src: SourceLocation::from_str("25:1:1").unwrap(), + name_location: None, + documentation: None, + node_type: StructDefinitionNodeType::StructDefinition, + scope: 1, + members: vec![], + canonical_name: "TestStruct".to_string(), + visibility: Visibility::Public, + } +} + +#[allow(dead_code)] +pub fn create_test_enum_definition() -> EnumDefinition { + EnumDefinition { + id: 6, + name: "TestEnum".to_string(), + src: SourceLocation::from_str("28:1:1").unwrap(), + name_location: None, + documentation: None, + node_type: EnumDefinitionNodeType::EnumDefinition, + members: vec![create_test_enum_value()], + canonical_name: "TestEnum".to_string(), + } +} + +#[allow(dead_code)] +pub fn create_test_event_definition() -> EventDefinition { + EventDefinition { + id: 7, + name: "TestEvent".to_string(), + src: SourceLocation::from_str("31:1:1").unwrap(), + name_location: None, + documentation: None, + node_type: EventDefinitionNodeType::EventDefinition, + parameters: ParameterList { + id: 1, + node_type: ParameterListNodeType::ParameterList, + parameters: vec![], + src: SourceLocation::from_str("34:1:1").unwrap(), + }, + anonymous: false, + event_selector: None, + } +} + +#[allow(dead_code)] +pub fn create_test_enum_value() -> EnumValue { + EnumValue { + id: 8, + name: "TestEnumValue".to_string(), + src: SourceLocation::from_str("37:1:1").unwrap(), + name_location: None, + node_type: EnumValueNodeType::EnumValue, + } +} + +#[allow(dead_code)] +pub fn create_test_using_for_directive() -> UsingForDirective { + UsingForDirective { + id: 9, + library_name: None, + src: SourceLocation::from_str("40:1:1").unwrap(), + node_type: UsingForDirectiveNodeType::UsingForDirective, + type_name: None, + function_list: vec![], + global: None, + } +} + +#[allow(dead_code)] +pub fn create_test_import_directive() -> ImportDirective { + ImportDirective { + id: 10, + src: SourceLocation::from_str("43:1:1").unwrap(), + node_type: ImportDirectiveNodeType::ImportDirective, + unit_alias: "Alias".to_string(), + absolute_path: "/home/user/test.sol".to_string(), + file: "test.sol".to_string(), + name_location: None, + scope: 0, + source_unit: 0, + symbol_aliases: vec![], + } +} + +#[allow(dead_code)] +pub fn create_test_error_definition() -> ErrorDefinition { + ErrorDefinition { + id: 11, + name: "TestError".to_string(), + src: SourceLocation::from_str("46:1:1").unwrap(), + name_location: "Here".to_string(), + documentation: None, + node_type: ErrorDefinitionNodeType::ErrorDefinition, + error_selector: None, + parameters: ParameterList { + id: 1, + node_type: ParameterListNodeType::ParameterList, + parameters: vec![], + src: SourceLocation::from_str("49:1:1").unwrap(), + }, + } +} + +#[allow(dead_code)] +pub fn create_test_function_call() -> FunctionCall { + FunctionCall { + id: 12, + src: SourceLocation::from_str("175:10:1").unwrap(), + node_type: FunctionCallNodeType::FunctionCall, + arguments: vec![], + expression: Box::new(Expression::Identifier(Identifier { + id: 1, + node_type: IdentifierNodeType::Identifier, + name: "number".to_string(), + src: SourceLocation::from_str("55:1:1").unwrap(), + referenced_declaration: Some(2), + type_descriptions: TypeDescriptions { + type_string: Some("uint256".to_string()), + type_identifier: None, + }, + overloaded_declarations: vec![], + argument_types: None, + })), + names: vec![], + type_descriptions: TypeDescriptions { + type_string: Some("uint256".to_string()), + type_identifier: None, + }, + argument_types: None, + is_constant: false, + is_l_value: false, + is_pure: false, + kind: FunctionCallKind::FunctionCall, + l_value_requested: false, + name_locations: vec![], + try_call: false, + } +} + +#[allow(dead_code)] +pub fn create_test_modifier_invocation() -> ModifierInvocation { + ModifierInvocation { + id: 13, + src: SourceLocation::from_str("58:1:1").unwrap(), + node_type: ModifierInvocationNodeType::ModifierInvocation, + arguments: None, + modifier_name: ModifierInvocationModifierName::Identifier(Identifier { + id: 1, + node_type: IdentifierNodeType::Identifier, + name: "modifier".to_string(), + src: SourceLocation::from_str("61:1:1").unwrap(), + referenced_declaration: Some(4), + type_descriptions: TypeDescriptions { + type_string: Some("modifier".to_string()), + type_identifier: None, + }, + overloaded_declarations: vec![], + argument_types: None, + }), + kind: Some(ModifierInvocationKind::ModifierInvocation), + } +} + +#[allow(dead_code)] +pub fn create_test_inheritance_specifier() -> InheritanceSpecifier { + InheritanceSpecifier { + id: 14, + src: SourceLocation::from_str("64:1:1").unwrap(), + node_type: InheritanceSpecifierNodeType::InheritanceSpecifier, + arguments: None, + base_name: InheritanceSpecifierBaseName::UserDefinedTypeName(UserDefinedTypeName { + id: 1, + node_type: UserDefinedTypeNameNodeType::UserDefinedTypeName, + referenced_declaration: 5, + src: SourceLocation::from_str("67:1:1").unwrap(), + type_descriptions: TypeDescriptions { + type_string: Some("TestStruct".to_string()), + type_identifier: None, + }, + contract_scope: (), + name: Some("TestStruct".to_string()), + path_node: None, + }), + } +} + +#[allow(dead_code)] +pub fn create_test_identifier() -> Identifier { + Identifier { + id: 15, + node_type: IdentifierNodeType::Identifier, + name: "number".to_string(), + src: SourceLocation::from_str("199:6:1").unwrap(), + referenced_declaration: Some(30), + type_descriptions: TypeDescriptions { + type_string: Some("uint256".to_string()), + type_identifier: None, + }, + overloaded_declarations: vec![], + argument_types: None, + } +} + +#[allow(dead_code)] +pub fn create_test_member_access() -> MemberAccess { + MemberAccess { + id: 16, + src: SourceLocation::from_str("73:1:1").unwrap(), + expression: Box::new(Expression::Identifier(Identifier { + id: 1, + node_type: IdentifierNodeType::Identifier, + name: "number".to_string(), + src: SourceLocation::from_str("76:1:1").unwrap(), + referenced_declaration: Some(123), + type_descriptions: TypeDescriptions { + type_string: Some("uint256".to_string()), + type_identifier: None, + }, + overloaded_declarations: vec![], + argument_types: None, + })), + member_name: "member".to_string(), + referenced_declaration: Some(3), + type_descriptions: TypeDescriptions { + type_string: Some("uint256".to_string()), + type_identifier: None, + }, + argument_types: None, + is_constant: false, + is_l_value: false, + is_pure: false, + l_value_requested: false, + member_location: None, + node_type: MemberAccessNodeType::MemberAccess, + } +} + +#[allow(dead_code)] +pub fn create_test_new_expression() -> NewExpression { + NewExpression { + id: 17, + src: SourceLocation::from_str("79:1:1").unwrap(), + node_type: NewExpressionNodeType::NewExpression, + type_name: TypeName::UserDefinedTypeName(UserDefinedTypeName { + id: 1, + node_type: UserDefinedTypeNameNodeType::UserDefinedTypeName, + referenced_declaration: 5, + src: SourceLocation::from_str("82:1:1").unwrap(), + type_descriptions: TypeDescriptions { + type_string: Some("TestStruct".to_string()), + type_identifier: None, + }, + contract_scope: (), + name: Some("TestStruct".to_string()), + path_node: None, + }), + argument_types: None, + is_constant: false, + is_l_value: Some(false), + is_pure: false, + l_value_requested: false, + type_descriptions: TypeDescriptions { + type_string: Some("TestStruct".to_string()), + type_identifier: None, + }, + } +} + +#[allow(dead_code)] +pub fn create_test_user_defined_type_name() -> UserDefinedTypeName { + UserDefinedTypeName { + id: 18, + node_type: UserDefinedTypeNameNodeType::UserDefinedTypeName, + referenced_declaration: 5, + src: SourceLocation::from_str("85:1:1").unwrap(), + type_descriptions: TypeDescriptions { + type_string: Some("TestStruct".to_string()), + type_identifier: None, + }, + contract_scope: (), + name: Some("TestStruct".to_string()), + path_node: None, + } +} + +#[allow(dead_code)] +pub fn create_test_identifier_path() -> IdentifierPath { + IdentifierPath { + id: 19, + node_type: IdentifierPathNodeType::IdentifierPath, + name: "IdPath".to_string(), + name_locations: vec![], + referenced_declaration: 15, + src: SourceLocation::from_str("88:1:1").unwrap(), + } +} + +#[allow(dead_code)] +pub fn create_test_ast_file() -> SolidityAstFile { + // a source file with every possible InteractableNode + let source = "pragma solidity ^0.8.0; . + + contract Counter { + uint256 public number; + uint256 public x = 2; + uint256 public y = x; + + function notUsed() internal { + uint256 x = 1; + number; + } + }"; + + let path = "test.sol"; + + let mut function = create_test_function_definition(); + function.src = SourceLocation::from_str("240:86:0").unwrap(); + function + .modifiers + .push(create_test_modifier_invocation().into()); + function.body = Some(Block { + documentation: None, + id: 30, + node_type: BlockNodeType::Block, + src: SourceLocation::from_str("91:1:1").unwrap(), + statements: Some( + [ + Statement::ExpressionStatement(ExpressionStatement { + expression: create_test_function_call().into(), + id: 100, + node_type: ExpressionStatementNodeType::ExpressionStatement, + src: SourceLocation::from_str("94:1:1").unwrap(), + documentation: None, + }), + Statement::ExpressionStatement(ExpressionStatement { + expression: create_test_member_access().into(), + id: 102, + node_type: ExpressionStatementNodeType::ExpressionStatement, + src: SourceLocation::from_str("97:1:1").unwrap(), + documentation: None, + }), + Statement::ExpressionStatement(ExpressionStatement { + expression: create_test_new_expression().into(), + id: 103, + node_type: ExpressionStatementNodeType::ExpressionStatement, + src: SourceLocation::from_str("100:1:1").unwrap(), + documentation: None, + }), + ] + .iter() + .cloned() + .collect(), + ), + }); + let mut contract = create_test_contract_definition(); + contract.src = SourceLocation::from_str("121:211:0").unwrap(); + contract.nodes.push(function.into()); + contract + .nodes + .push(create_test_variable_declaration().into()); + contract.nodes.push(create_test_enum_definition().into()); + contract.nodes.push(create_test_struct_definition().into()); + contract.nodes.push(create_test_event_definition().into()); + contract + .nodes + .push(create_test_using_for_directive().into()); + contract.nodes.push(create_test_error_definition().into()); + + contract + .base_contracts + .push(create_test_inheritance_specifier().into()); + + let mut multiple_import = create_test_import_directive(); + multiple_import.unit_alias = "".to_string(); + multiple_import.symbol_aliases = vec![ImportDirectiveSymbolAliasesItem { + foreign: create_test_identifier(), + local: Some("TestLocal".to_string()), + name_location: None, + }]; + let mut empty_import = create_test_import_directive(); + empty_import.unit_alias = "".to_string(); + empty_import.symbol_aliases = vec![]; + SolidityAstFile { + file: osmium_libs_solidity_ast_extractor::types::SolidityFile { + path: path.to_string(), + content: source.to_string(), + }, + ast: SourceUnit { + id: 0, + nodes: vec![ + contract.into(), + create_test_function_definition().into(), + create_test_variable_declaration().into(), + create_test_enum_definition().into(), + create_test_import_directive().into(), + multiple_import.into(), + empty_import.into(), + create_test_struct_definition().into(), + ], + src: SourceLocation::from_str("0:332:0").unwrap(), + absolute_path: "/home/user/test.sol".to_string(), + experimental_solidity: None, + exported_symbols: None, + license: None, + node_type: SourceUnitNodeType::SourceUnit, + }, + } +} + +// create a function for each InteractableNode returning an ast file like the create_test_ast_file but with only the needed nodes + +#[allow(dead_code)] +pub fn create_test_ast_file_contract_definition() -> SolidityAstFile { + let source = "pragma solidity ^0.8.0; . + + contract Counter { + }"; + + let path = "test.sol"; + + let mut contract = create_test_contract_definition(); + contract.src = SourceLocation::from_str("121:24:0").unwrap(); // index:range:0 (index is the start of the range) + SolidityAstFile { + file: osmium_libs_solidity_ast_extractor::types::SolidityFile { + path: path.to_string(), + content: source.to_string(), + }, + ast: SourceUnit { + id: 0, + nodes: vec![contract.into()], + src: SourceLocation::from_str("0:145:0").unwrap(), + absolute_path: "/home/user/test.sol".to_string(), + experimental_solidity: None, + exported_symbols: None, + license: None, + node_type: SourceUnitNodeType::SourceUnit, + }, + } +} + +#[allow(dead_code)] +pub fn create_test_ast_file_function_definition() -> SolidityAstFile { + let source = "pragma solidity ^0.8.0; . + + function notUsed() internal { + }"; + + let path = "test.sol"; + + let mut function = create_test_function_definition(); + function.src = SourceLocation::from_str("121:35:0").unwrap(); + SolidityAstFile { + file: osmium_libs_solidity_ast_extractor::types::SolidityFile { + path: path.to_string(), + content: source.to_string(), + }, + ast: SourceUnit { + id: 0, + nodes: vec![function.into()], + src: SourceLocation::from_str("0:155:0").unwrap(), + absolute_path: "/home/user/test.sol".to_string(), + experimental_solidity: None, + exported_symbols: None, + license: None, + node_type: SourceUnitNodeType::SourceUnit, + }, + } +} + +#[allow(dead_code)] +pub fn create_test_ast_file_variable_declaration() -> SolidityAstFile { + let source = "pragma solidity ^0.8.0; . + + uint256 public number; + "; + + let path = "test.sol"; + + let mut variable = create_test_variable_declaration(); + variable.src = SourceLocation::from_str("121:21:0").unwrap(); + SolidityAstFile { + file: osmium_libs_solidity_ast_extractor::types::SolidityFile { + path: path.to_string(), + content: source.to_string(), + }, + ast: SourceUnit { + id: 0, + nodes: vec![variable.into()], + src: SourceLocation::from_str("0:148:0").unwrap(), + absolute_path: "/home/user/test.sol".to_string(), + experimental_solidity: None, + exported_symbols: None, + license: None, + node_type: SourceUnitNodeType::SourceUnit, + }, + } +} + +#[allow(dead_code)] +pub fn create_test_ast_file_enum_definition() -> SolidityAstFile { + let source = "pragma solidity ^0.8.0; . + + enum TestEnum { + TestEnumValue + } + "; + + let path = "test.sol"; + + let mut enum_ = create_test_enum_definition(); + enum_.src = SourceLocation::from_str("121:35:0").unwrap(); + SolidityAstFile { + file: osmium_libs_solidity_ast_extractor::types::SolidityFile { + path: path.to_string(), + content: source.to_string(), + }, + ast: SourceUnit { + id: 0, + nodes: vec![enum_.into()], + src: SourceLocation::from_str("0:169:0").unwrap(), + absolute_path: "/home/user/test.sol".to_string(), + experimental_solidity: None, + exported_symbols: None, + license: None, + node_type: SourceUnitNodeType::SourceUnit, + }, + } +} + +#[allow(dead_code)] +pub fn create_test_ast_file_enum_value() -> SolidityAstFile { + let source = "pragma solidity ^0.8.0; . + + enum TestEnum { + TestEnumValue + } + "; + + let path = "test.sol"; + + let mut enum_value = create_test_enum_definition(); + enum_value.src = SourceLocation::from_str("121:43:0").unwrap(); + let mut value = create_test_enum_value(); + value.src = SourceLocation::from_str("145:13:0").unwrap(); + enum_value.members = vec![value]; + SolidityAstFile { + file: osmium_libs_solidity_ast_extractor::types::SolidityFile { + path: path.to_string(), + content: source.to_string(), + }, + ast: SourceUnit { + id: 0, + nodes: vec![enum_value.into()], + src: SourceLocation::from_str("0:169:0").unwrap(), + absolute_path: "/home/user/test.sol".to_string(), + experimental_solidity: None, + exported_symbols: None, + license: None, + node_type: SourceUnitNodeType::SourceUnit, + }, + } +} + +#[allow(dead_code)] +pub fn create_test_ast_file_struct_definition() -> SolidityAstFile { + let source = "pragma solidity ^0.8.0; . + + struct TestStruct { + } + "; + + let path = "test.sol"; + + let mut struct_ = create_test_struct_definition(); + struct_.src = SourceLocation::from_str("121:35:0").unwrap(); + SolidityAstFile { + file: osmium_libs_solidity_ast_extractor::types::SolidityFile { + path: path.to_string(), + content: source.to_string(), + }, + ast: SourceUnit { + id: 0, + nodes: vec![struct_.into()], + src: SourceLocation::from_str("0:151:0").unwrap(), + absolute_path: "/home/user/test.sol".to_string(), + experimental_solidity: None, + exported_symbols: None, + license: None, + node_type: SourceUnitNodeType::SourceUnit, + }, + } +} + +#[allow(dead_code)] +pub fn create_test_ast_file_event_definition() -> SolidityAstFile { + let source = "pragma solidity ^0.8.0; . + + contract Test { + event TestEvent() {} + } + "; + + let path = "test.sol"; + + let mut contract = create_test_contract_definition(); + contract.src = SourceLocation::from_str("121:50:0").unwrap(); + let mut event = create_test_event_definition(); + event.src = SourceLocation::from_str("145:20:0").unwrap(); + contract.nodes.push(event.into()); + SolidityAstFile { + file: osmium_libs_solidity_ast_extractor::types::SolidityFile { + path: path.to_string(), + content: source.to_string(), + }, + ast: SourceUnit { + id: 0, + nodes: vec![contract.into()], + src: SourceLocation::from_str("0:176:0").unwrap(), + absolute_path: "/home/user/test.sol".to_string(), + experimental_solidity: None, + exported_symbols: None, + license: None, + node_type: SourceUnitNodeType::SourceUnit, + }, + } +} + +#[allow(dead_code)] +pub fn create_test_ast_file_using_for_directive() -> SolidityAstFile { + let source = "pragma solidity ^0.8.0; . + + contract Test { + using for uint256; + } + "; + + let path = "test.sol"; + + let mut contract = create_test_contract_definition(); + contract.src = SourceLocation::from_str("121:48:0").unwrap(); + let mut using_for = create_test_using_for_directive(); + using_for.src = SourceLocation::from_str("145:18:0").unwrap(); + contract.nodes.push(using_for.into()); + SolidityAstFile { + file: osmium_libs_solidity_ast_extractor::types::SolidityFile { + path: path.to_string(), + content: source.to_string(), + }, + ast: SourceUnit { + id: 0, + nodes: vec![contract.into()], + src: SourceLocation::from_str("0:174:0").unwrap(), + absolute_path: "/home/user/test.sol".to_string(), + experimental_solidity: None, + exported_symbols: None, + license: None, + node_type: SourceUnitNodeType::SourceUnit, + }, + } +} + +#[allow(dead_code)] +pub fn create_test_ast_file_error_definition() -> SolidityAstFile { + let source = "pragma solidity ^0.8.0; . + + error TestError { + } + "; + + let path = "test.sol"; + + let mut error = create_test_error_definition(); + error.src = SourceLocation::from_str("121:23:0").unwrap(); + SolidityAstFile { + file: osmium_libs_solidity_ast_extractor::types::SolidityFile { + path: path.to_string(), + content: source.to_string(), + }, + ast: SourceUnit { + id: 0, + nodes: vec![error.into()], + src: SourceLocation::from_str("0:149:0").unwrap(), + absolute_path: "/home/user/test.sol".to_string(), + experimental_solidity: None, + exported_symbols: None, + license: None, + node_type: SourceUnitNodeType::SourceUnit, + }, + } +} + +#[allow(dead_code)] +pub fn create_test_ast_file_function_call() -> SolidityAstFile { + let source = "pragma solidity ^0.8.0; . + + contract Test { + function test() { + notUsed(); + } + } + "; + + let path = "test.sol"; + + let mut contract = create_test_contract_definition(); + contract.src = SourceLocation::from_str("121:80:0").unwrap(); + let mut function = create_test_function_definition(); + function.src = SourceLocation::from_str("145:50:0").unwrap(); + function.body = Some(Block { + documentation: None, + id: 30, + node_type: BlockNodeType::Block, + src: SourceLocation::from_str("162:32:0").unwrap(), + statements: Some( + [Statement::ExpressionStatement(ExpressionStatement { + expression: create_test_function_call().into(), + id: 100, + node_type: ExpressionStatementNodeType::ExpressionStatement, + src: SourceLocation::from_str("175:10:0").unwrap(), + documentation: None, + })] + .iter() + .cloned() + .collect(), + ), + }); + contract.nodes.push(function.into()); + SolidityAstFile { + file: osmium_libs_solidity_ast_extractor::types::SolidityFile { + path: path.to_string(), + content: source.to_string(), + }, + ast: SourceUnit { + id: 0, + nodes: vec![contract.into()], + src: SourceLocation::from_str("0:206:0").unwrap(), + absolute_path: "/home/user/test.sol".to_string(), + experimental_solidity: None, + exported_symbols: None, + license: None, + node_type: SourceUnitNodeType::SourceUnit, + }, + } +} + +#[allow(dead_code)] +pub fn create_test_ast_file_modifier_invocation() -> SolidityAstFile { + let source = "pragma solidity ^0.8.0; . + + contract Test { + function testFunc() test {} + } + "; + + let path = "test.sol"; + + let mut contract = create_test_contract_definition(); + contract.src = SourceLocation::from_str("121:57:0").unwrap(); + let mut function = create_test_function_definition(); + function.src = SourceLocation::from_str("145:27:0").unwrap(); + let mut modifier = create_test_modifier_invocation(); + modifier.src = SourceLocation::from_str("165:4:0").unwrap(); + function.modifiers.push(modifier); + contract.nodes.push(function.into()); + + SolidityAstFile { + file: osmium_libs_solidity_ast_extractor::types::SolidityFile { + path: path.to_string(), + content: source.to_string(), + }, + ast: SourceUnit { + id: 0, + nodes: vec![contract.into()], + src: SourceLocation::from_str("0:177:0").unwrap(), + absolute_path: "/home/user/test.sol".to_string(), + experimental_solidity: None, + exported_symbols: None, + license: None, + node_type: SourceUnitNodeType::SourceUnit, + }, + } +} + +#[allow(dead_code)] +pub fn create_test_ast_file_inheritance_specifier() -> SolidityAstFile { + let source = "pragma solidity ^0.8.0; . + + contract Test is TestStruct { + } + "; + + let path = "test.sol"; + + let mut contract = create_test_contract_definition(); + contract.src = SourceLocation::from_str("121:35:0").unwrap(); + let mut inheritance = create_test_inheritance_specifier(); + inheritance.src = SourceLocation::from_str("138:10:0").unwrap(); + contract.base_contracts.push(inheritance); + + SolidityAstFile { + file: osmium_libs_solidity_ast_extractor::types::SolidityFile { + path: path.to_string(), + content: source.to_string(), + }, + ast: SourceUnit { + id: 0, + nodes: vec![contract.into()], + src: SourceLocation::from_str("0:161:0").unwrap(), + absolute_path: "/home/user/test.sol".to_string(), + experimental_solidity: None, + exported_symbols: None, + license: None, + node_type: SourceUnitNodeType::SourceUnit, + }, + } +} + +#[allow(dead_code)] +pub fn create_test_ast_file_identifier() -> SolidityAstFile { + let source = "pragma solidity ^0.8.0; . + + contract Test { + uint256 number; + function test() { + number; + } + } + "; + + let path = "test.sol"; + + let mut contract = create_test_contract_definition(); + contract.src = SourceLocation::from_str("121:80:0").unwrap(); + let mut function = create_test_function_definition(); + function.src = SourceLocation::from_str("145:50:0").unwrap(); + function.body = Some(Block { + documentation: None, + id: 30, + node_type: BlockNodeType::Block, + src: SourceLocation::from_str("162:32:0").unwrap(), + statements: Some( + [Statement::ExpressionStatement(ExpressionStatement { + expression: create_test_identifier().into(), + id: 100, + node_type: ExpressionStatementNodeType::ExpressionStatement, + src: SourceLocation::from_str("175:10:0").unwrap(), + documentation: None, + })] + .iter() + .cloned() + .collect(), + ), + }); + contract.nodes.push(function.into()); + contract + .nodes + .push(create_test_variable_declaration().into()); + SolidityAstFile { + file: osmium_libs_solidity_ast_extractor::types::SolidityFile { + path: path.to_string(), + content: source.to_string(), + }, + ast: SourceUnit { + id: 0, + nodes: vec![contract.into()], + src: SourceLocation::from_str("0:206:0").unwrap(), + absolute_path: "/home/user/test.sol".to_string(), + experimental_solidity: None, + exported_symbols: None, + license: None, + node_type: SourceUnitNodeType::SourceUnit, + }, + } +} + +#[allow(dead_code)] +pub fn create_test_ast_file_member_access() -> SolidityAstFile { + let source = "pragma solidity ^0.8.0; . + + contract Test { + uint256 number; + function test() { + number.member; + } + } + "; + + let path = "test.sol"; + + let mut contract = create_test_contract_definition(); + contract.src = SourceLocation::from_str("121:108:0").unwrap(); + let mut member_access = create_test_member_access(); + member_access.src = SourceLocation::from_str("205:7:0").unwrap(); + let mut function = create_test_function_definition(); + function.src = SourceLocation::from_str("169:54:0").unwrap(); + function.body = Some(Block { + documentation: None, + id: 30, + node_type: BlockNodeType::Block, + src: SourceLocation::from_str("185:38:0").unwrap(), + statements: Some( + [Statement::ExpressionStatement(ExpressionStatement { + expression: member_access.into(), + id: 100, + node_type: ExpressionStatementNodeType::ExpressionStatement, + src: SourceLocation::from_str("205:7:0").unwrap(), + documentation: None, + })] + .iter() + .cloned() + .collect(), + ), + }); + contract.nodes.push(function.into()); + contract + .nodes + .push(create_test_variable_declaration().into()); + + SolidityAstFile { + file: osmium_libs_solidity_ast_extractor::types::SolidityFile { + path: path.to_string(), + content: source.to_string(), + }, + ast: SourceUnit { + id: 0, + nodes: vec![contract.into()], + src: SourceLocation::from_str("0:234:0").unwrap(), + absolute_path: "/home/user/test.sol".to_string(), + experimental_solidity: None, + exported_symbols: None, + license: None, + node_type: SourceUnitNodeType::SourceUnit, + }, + } +} + +#[allow(dead_code)] +pub fn create_test_ast_file_new_expression() -> SolidityAstFile { + let source = "pragma solidity ^0.8.0; . + + contract Test { + uint256 number; + function test() { + new TestStruct(); + } + } + "; + + let path = "test.sol"; + + let mut contract = create_test_contract_definition(); + contract.src = SourceLocation::from_str("121:111:0").unwrap(); + let mut new_expression = create_test_new_expression(); + new_expression.src = SourceLocation::from_str("199:16:0").unwrap(); + let mut function = create_test_function_definition(); + function.src = SourceLocation::from_str("169:57:0").unwrap(); + function.body = Some(Block { + documentation: None, + id: 30, + node_type: BlockNodeType::Block, + src: SourceLocation::from_str("185:41:0").unwrap(), + statements: Some( + [Statement::ExpressionStatement(ExpressionStatement { + expression: new_expression.into(), + id: 100, + node_type: ExpressionStatementNodeType::ExpressionStatement, + src: SourceLocation::from_str("175:10:0").unwrap(), + documentation: None, + })] + .iter() + .cloned() + .collect(), + ), + }); + contract.nodes.push(function.into()); + contract + .nodes + .push(create_test_variable_declaration().into()); + + SolidityAstFile { + file: osmium_libs_solidity_ast_extractor::types::SolidityFile { + path: path.to_string(), + content: source.to_string(), + }, + ast: SourceUnit { + id: 0, + nodes: vec![contract.into()], + src: SourceLocation::from_str("0:237:0").unwrap(), + absolute_path: "/home/user/test.sol".to_string(), + experimental_solidity: None, + exported_symbols: None, + license: None, + node_type: SourceUnitNodeType::SourceUnit, + }, + } +} + +#[allow(dead_code)] +pub fn create_test_ast_file_user_defined_type_name() -> SolidityAstFile { + let source = "pragma solidity ^0.8.0; . + + contract Test { + uint256 number; + function test() { + TestStruct testStruct = new TestStruct(); + } + } + "; + + let path = "test.sol"; + + let mut contract = create_test_contract_definition(); + contract.src = SourceLocation::from_str("121:111:0").unwrap(); + let mut new_expression = create_test_new_expression(); + new_expression.src = SourceLocation::from_str("223:16:0").unwrap(); + let mut user_defined_type_name = create_test_user_defined_type_name(); + user_defined_type_name.src = SourceLocation::from_str("227:10:0").unwrap(); + new_expression.type_name = TypeName::UserDefinedTypeName(user_defined_type_name.clone()); + let mut function = create_test_function_definition(); + function.src = SourceLocation::from_str("169:81:0").unwrap(); + function.body = Some(Block { + documentation: None, + id: 30, + node_type: BlockNodeType::Block, + src: SourceLocation::from_str("185:65:0").unwrap(), + statements: Some( + [Statement::ExpressionStatement(ExpressionStatement { + expression: new_expression.into(), + id: 100, + node_type: ExpressionStatementNodeType::ExpressionStatement, + src: SourceLocation::from_str("223:16:0").unwrap(), + documentation: None, + })] + .iter() + .cloned() + .collect(), + ), + }); + contract.nodes.push(function.into()); + contract + .nodes + .push(create_test_variable_declaration().into()); + + SolidityAstFile { + file: osmium_libs_solidity_ast_extractor::types::SolidityFile { + path: path.to_string(), + content: source.to_string(), + }, + ast: SourceUnit { + id: 0, + nodes: vec![contract.into()], + src: SourceLocation::from_str("0:261:0").unwrap(), + absolute_path: "/home/user/test.sol".to_string(), + experimental_solidity: None, + exported_symbols: None, + license: None, + node_type: SourceUnitNodeType::SourceUnit, + }, + } +} + +#[allow(dead_code)] +pub fn create_test_ast_file_modifier_definition() -> SolidityAstFile { + let source = "pragma solidity ^0.8.0; . + + contract Test { + modifier modifier() {} + } + "; + + let path = "test.sol"; + + let mut contract = create_test_contract_definition(); + contract.src = SourceLocation::from_str("121:35:0").unwrap(); + let mut modifier = create_test_modifier_definition(); + modifier.src = SourceLocation::from_str("145:22:0").unwrap(); + contract.nodes.push(modifier.into()); + + SolidityAstFile { + file: osmium_libs_solidity_ast_extractor::types::SolidityFile { + path: path.to_string(), + content: source.to_string(), + }, + ast: SourceUnit { + id: 0, + nodes: vec![contract.into()], + src: SourceLocation::from_str("0:261:0").unwrap(), + absolute_path: "/home/user/test.sol".to_string(), + experimental_solidity: None, + exported_symbols: None, + license: None, + node_type: SourceUnitNodeType::SourceUnit, + }, + } +} + +#[allow(dead_code)] +pub fn create_test_ast_file_identifier_path() -> SolidityAstFile { + let source = "pragma solidity ^0.8.0; . + + contract Test is TestStruct { + } + "; + + let path = "test.sol"; + + let mut contract = create_test_contract_definition(); + contract.src = SourceLocation::from_str("121:35:0").unwrap(); + let mut identifier_path = create_test_identifier_path(); + identifier_path.src = SourceLocation::from_str("138:10:0").unwrap(); + let mut inheritance = create_test_inheritance_specifier(); + inheritance.src = SourceLocation::from_str("138:10:0").unwrap(); + inheritance.base_name = InheritanceSpecifierBaseName::IdentifierPath(identifier_path.clone()); + contract.base_contracts.push(inheritance); + + SolidityAstFile { + file: osmium_libs_solidity_ast_extractor::types::SolidityFile { + path: path.to_string(), + content: source.to_string(), + }, + ast: SourceUnit { + id: 0, + nodes: vec![contract.into()], + src: SourceLocation::from_str("0:161:0").unwrap(), + absolute_path: "/home/user/test.sol".to_string(), + experimental_solidity: None, + exported_symbols: None, + license: None, + node_type: SourceUnitNodeType::SourceUnit, + }, + } +} + +#[allow(dead_code)] +pub fn create_test_ast_file_import_directive() -> SolidityAstFile { + let source = "pragma solidity ^0.8.0; . + + import 'test.sol' as Alias; + "; + + let path = "test.sol"; + + let mut import = create_test_import_directive(); + import.src = SourceLocation::from_str("121:21:0").unwrap(); + SolidityAstFile { + file: osmium_libs_solidity_ast_extractor::types::SolidityFile { + path: path.to_string(), + content: source.to_string(), + }, + ast: SourceUnit { + id: 0, + nodes: vec![import.into()], + src: SourceLocation::from_str("0:161:0").unwrap(), + absolute_path: "/home/user/test.sol".to_string(), + experimental_solidity: None, + exported_symbols: None, + license: None, + node_type: SourceUnitNodeType::SourceUnit, + }, + } +} diff --git a/libs/code-actions/src/types.rs b/libs/code-actions/src/types.rs index da8fc2b..9eee2d6 100644 --- a/libs/code-actions/src/types.rs +++ b/libs/code-actions/src/types.rs @@ -1,7 +1,7 @@ use crate::utils::source_location_to_range; use solc_ast_rs_types::types::*; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Eq, PartialEq)] pub struct Position { pub line: u32, pub column: u32, @@ -13,13 +13,13 @@ impl Default for Position { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct Range { pub index: u32, pub length: u32, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct Location { pub start: Position, pub end: Position, @@ -207,6 +207,7 @@ impl InteractableNode { InteractableNode::UserDefinedTypeName(node) => Some(node.referenced_declaration), InteractableNode::Identifier(node) => node.referenced_declaration, InteractableNode::MemberAccess(node) => node.referenced_declaration, + InteractableNode::IdentifierPath(node) => Some(node.referenced_declaration), _ => None, } } @@ -237,9 +238,7 @@ impl InteractableNode { InteractableNode::EnumValue(node) => source_location_to_range( node.name_location.as_ref().unwrap_or(&node.src.to_owned()), ), - InteractableNode::ErrorDefinition(node) => { - source_location_to_range(&node.name_location) - } + InteractableNode::ErrorDefinition(node) => source_location_to_range(&node.src), InteractableNode::UsingForDirective(node) => source_location_to_range(&node.src), InteractableNode::ImportDirective(node) => source_location_to_range(&node.src), InteractableNode::FunctionCall(node) => source_location_to_range(&node.src), @@ -259,3 +258,482 @@ impl InteractableNode { } } } + +#[cfg(test)] +mod test { + + use crate::test_utils::*; + + use super::*; + + #[test] + fn test_get_contract_id() { + let node = InteractableNode::ContractDefinition(create_test_contract_definition()); + assert_eq!(node.get_id(), 1); + } + + #[test] + fn test_get_function_definition_id() { + let node = InteractableNode::FunctionDefinition(create_test_function_definition()); + assert_eq!(node.get_id(), 2); + } + + #[test] + fn test_get_variable_declaration_id() { + let node = InteractableNode::VariableDeclaration(create_test_variable_declaration()); + assert_eq!(node.get_id(), 3); + } + + #[test] + fn test_get_modifier_definition_id() { + let node = InteractableNode::ModifierDefinition(create_test_modifier_definition()); + assert_eq!(node.get_id(), 4); + } + + #[test] + fn test_get_struct_definition_id() { + let node = InteractableNode::StructDefinition(create_test_struct_definition()); + assert_eq!(node.get_id(), 5); + } + + #[test] + fn test_get_enum_definition_id() { + let node = InteractableNode::EnumDefinition(create_test_enum_definition()); + assert_eq!(node.get_id(), 6); + } + + #[test] + fn test_get_event_definition_id() { + let node = InteractableNode::EventDefinition(create_test_event_definition()); + assert_eq!(node.get_id(), 7); + } + + #[test] + fn test_get_enum_value_id() { + let node = InteractableNode::EnumValue(create_test_enum_value()); + assert_eq!(node.get_id(), 8); + } + + #[test] + fn test_get_using_for_directive_id() { + let node = InteractableNode::UsingForDirective(create_test_using_for_directive()); + assert_eq!(node.get_id(), 9); + } + + #[test] + fn test_get_import_directive_id() { + let node = InteractableNode::ImportDirective(create_test_import_directive()); + assert_eq!(node.get_id(), 10); + } + + #[test] + fn test_get_error_definition_id() { + let node = InteractableNode::ErrorDefinition(create_test_error_definition()); + assert_eq!(node.get_id(), 11); + } + + #[test] + fn test_get_function_call_id() { + let node = InteractableNode::FunctionCall(create_test_function_call()); + assert_eq!(node.get_id(), 12); + } + + #[test] + fn test_get_modifier_invocation_id() { + let node = InteractableNode::ModifierInvocation(create_test_modifier_invocation()); + assert_eq!(node.get_id(), 13); + } + + #[test] + fn test_get_inheritance_specifier_id() { + let node = InteractableNode::InheritanceSpecifier(create_test_inheritance_specifier()); + assert_eq!(node.get_id(), 14); + } + + #[test] + fn test_get_identifier_id() { + let node = InteractableNode::Identifier(create_test_identifier()); + assert_eq!(node.get_id(), 15); + } + + #[test] + fn test_get_member_access_id() { + let node = InteractableNode::MemberAccess(create_test_member_access()); + assert_eq!(node.get_id(), 16); + } + + #[test] + fn test_get_new_expression_id() { + let node = InteractableNode::NewExpression( + create_test_new_expression(), + Box::new(InteractableNode::VariableDeclaration( + create_test_variable_declaration(), + )), + ); + assert_eq!(node.get_id(), 17); + } + + #[test] + fn test_get_user_defined_type_name_id() { + let node = InteractableNode::UserDefinedTypeName(create_test_user_defined_type_name()); + assert_eq!(node.get_id(), 18); + } + + #[test] + fn test_get_identifier_path_id() { + let node = InteractableNode::IdentifierPath(create_test_identifier_path()); + assert_eq!(node.get_id(), 19); + } + + #[test] + fn test_get_contract_ref_id() { + let node = InteractableNode::ContractDefinition(create_test_contract_definition()); + assert_eq!(node.get_reference_id(), None); + } + + #[test] + fn test_get_function_definition_ref_id() { + let node = InteractableNode::FunctionDefinition(create_test_function_definition()); + assert_eq!(node.get_reference_id(), None); + } + + #[test] + fn test_get_variable_declaration_ref_id() { + let node = InteractableNode::VariableDeclaration(create_test_variable_declaration()); + assert_eq!(node.get_reference_id(), Some(3)); + } + + #[test] + fn test_get_modifier_definition_ref_id() { + let node = InteractableNode::ModifierDefinition(create_test_modifier_definition()); + assert_eq!(node.get_reference_id(), None); + } + + #[test] + fn test_get_struct_definition_ref_id() { + let node = InteractableNode::StructDefinition(create_test_struct_definition()); + assert_eq!(node.get_reference_id(), None); + } + + #[test] + fn test_get_enum_definition_ref_id() { + let node = InteractableNode::EnumDefinition(create_test_enum_definition()); + assert_eq!(node.get_reference_id(), None); + } + + #[test] + fn test_get_event_definition_ref_id() { + let node = InteractableNode::EventDefinition(create_test_event_definition()); + assert_eq!(node.get_reference_id(), None); + } + + #[test] + fn test_get_enum_value_ref_id() { + let node = InteractableNode::EnumValue(create_test_enum_value()); + assert_eq!(node.get_reference_id(), None); + } + + #[test] + fn test_get_using_for_directive_ref_id() { + let node = InteractableNode::UsingForDirective(create_test_using_for_directive()); + assert_eq!(node.get_reference_id(), None); + } + + #[test] + fn test_get_import_directive_ref_id() { + let node = InteractableNode::ImportDirective(create_test_import_directive()); + assert_eq!(node.get_reference_id(), None); + } + + #[test] + fn test_get_error_definition_ref_id() { + let node = InteractableNode::ErrorDefinition(create_test_error_definition()); + assert_eq!(node.get_reference_id(), None); + } + + #[test] + fn test_get_function_call_ref_id() { + let node = InteractableNode::FunctionCall(create_test_function_call()); + assert_eq!(node.get_reference_id(), None); + } + + #[test] + fn test_get_modifier_invocation_ref_id() { + let node = InteractableNode::ModifierInvocation(create_test_modifier_invocation()); + assert_eq!(node.get_reference_id(), None); + } + + #[test] + fn test_get_inheritance_specifier_ref_id() { + let node = InteractableNode::InheritanceSpecifier(create_test_inheritance_specifier()); + assert_eq!(node.get_reference_id(), Some(5)); + } + + #[test] + fn test_get_identifier_ref_id() { + let node = InteractableNode::Identifier(create_test_identifier()); + assert_eq!(node.get_reference_id(), Some(30)); + } + + #[test] + fn test_get_member_access_ref_id() { + let node = InteractableNode::MemberAccess(create_test_member_access()); + assert_eq!(node.get_reference_id(), Some(3)); + } + + #[test] + fn test_get_new_expression_ref_id() { + let node = InteractableNode::NewExpression( + create_test_new_expression(), + Box::new(InteractableNode::VariableDeclaration( + create_test_variable_declaration(), + )), + ); + assert_eq!(node.get_reference_id(), None); + } + + #[test] + fn test_get_user_defined_type_name_ref_id() { + let node = InteractableNode::UserDefinedTypeName(create_test_user_defined_type_name()); + assert_eq!(node.get_reference_id(), Some(5)); + } + + #[test] + fn test_get_identifier_path_ref_id() { + let node = InteractableNode::IdentifierPath(create_test_identifier_path()); + assert_eq!(node.get_reference_id(), Some(15)); + } + + #[test] + fn test_get_contract_definition_range() { + let node = InteractableNode::ContractDefinition(create_test_contract_definition()); + assert_eq!( + node.get_range(), + Range { + index: 29, + length: 215 + } + ); + } + + #[test] + fn test_get_function_definition_range() { + let node = InteractableNode::FunctionDefinition(create_test_function_definition()); + assert_eq!( + node.get_range(), + Range { + index: 152, + length: 86 + } + ); + } + + #[test] + fn test_get_variable_declaration_range() { + let node = InteractableNode::VariableDeclaration(create_test_variable_declaration()); + assert_eq!( + node.get_range(), + Range { + index: 56, + length: 21 + } + ); + } + + #[test] + fn test_get_modifier_definition_range() { + let node = InteractableNode::ModifierDefinition(create_test_modifier_definition()); + assert_eq!( + node.get_range(), + Range { + index: 16, + length: 1 + } + ); + } + + #[test] + fn test_get_struct_definition_range() { + let node = InteractableNode::StructDefinition(create_test_struct_definition()); + assert_eq!( + node.get_range(), + Range { + index: 25, + length: 1 + } + ); + } + + #[test] + fn test_get_enum_definition_range() { + let node = InteractableNode::EnumDefinition(create_test_enum_definition()); + assert_eq!( + node.get_range(), + Range { + index: 28, + length: 1 + } + ); + } + + #[test] + fn test_get_event_definition_range() { + let node = InteractableNode::EventDefinition(create_test_event_definition()); + assert_eq!( + node.get_range(), + Range { + index: 31, + length: 1 + } + ); + } + + #[test] + fn test_get_enum_value_range() { + let node = InteractableNode::EnumValue(create_test_enum_value()); + assert_eq!( + node.get_range(), + Range { + index: 37, + length: 1 + } + ); + } + + #[test] + fn test_get_using_for_directive_range() { + let node = InteractableNode::UsingForDirective(create_test_using_for_directive()); + assert_eq!( + node.get_range(), + Range { + index: 40, + length: 1 + } + ); + } + + #[test] + fn test_get_import_directive_range() { + let node = InteractableNode::ImportDirective(create_test_import_directive()); + assert_eq!( + node.get_range(), + Range { + index: 43, + length: 1 + } + ); + } + + #[test] + fn test_get_error_definition_range() { + let node = InteractableNode::ErrorDefinition(create_test_error_definition()); + assert_eq!( + node.get_range(), + Range { + index: 46, + length: 1 + } + ); + } + + #[test] + fn test_get_function_call_range() { + let node = InteractableNode::FunctionCall(create_test_function_call()); + assert_eq!( + node.get_range(), + Range { + index: 175, + length: 10 + } + ); + } + + #[test] + fn test_get_modifier_invocation_range() { + let node = InteractableNode::ModifierInvocation(create_test_modifier_invocation()); + assert_eq!( + node.get_range(), + Range { + index: 58, + length: 1 + } + ); + } + + #[test] + fn test_get_inheritance_specifier_range() { + let node = InteractableNode::InheritanceSpecifier(create_test_inheritance_specifier()); + assert_eq!( + node.get_range(), + Range { + index: 64, + length: 1 + } + ); + } + + #[test] + fn test_get_identifier_range() { + let node = InteractableNode::Identifier(create_test_identifier()); + assert_eq!( + node.get_range(), + Range { + index: 199, + length: 6 + } + ); + } + + #[test] + fn test_get_member_access_range() { + let node = InteractableNode::MemberAccess(create_test_member_access()); + assert_eq!( + node.get_range(), + Range { + index: 73, + length: 1 + } + ); + } + + #[test] + fn test_get_new_expression_range() { + let node = InteractableNode::NewExpression( + create_test_new_expression(), + Box::new(InteractableNode::VariableDeclaration( + create_test_variable_declaration(), + )), + ); + assert_eq!( + node.get_range(), + Range { + index: 79, + length: 1 + } + ); + } + + #[test] + fn test_get_user_defined_type_name_range() { + let node = InteractableNode::UserDefinedTypeName(create_test_user_defined_type_name()); + assert_eq!( + node.get_range(), + Range { + index: 85, + length: 1 + } + ); + } + + #[test] + fn test_get_identifier_path_range() { + let node = InteractableNode::IdentifierPath(create_test_identifier_path()); + assert_eq!( + node.get_range(), + Range { + index: 88, + length: 1 + } + ); + } +} diff --git a/libs/code-actions/src/utils.rs b/libs/code-actions/src/utils.rs index 867a606..5b265ff 100644 --- a/libs/code-actions/src/utils.rs +++ b/libs/code-actions/src/utils.rs @@ -2,9 +2,8 @@ use crate::{ types::{InteractableNode, Position, Range}, Location, }; -use log::info; use osmium_libs_solidity_ast_extractor::types::SolidityAstFile; -use solc_ast_rs_types::types::SourceLocation; +use solc_ast_rs_types::types::*; pub fn is_node_in_range(node: &SourceLocation, position: &Position, source: &str) -> bool { let range = source_location_to_range(node); @@ -21,10 +20,10 @@ pub fn log_is_node_in_range(node: &SourceLocation, position: &Position, source: let range = source_location_to_range(node); let index = position_to_index(position, source); - info!("Node Range: {:?}", range); - info!("Position: {:?}", position); - info!("Position Index: {:?}", index); - info!("Source: {:?}", source); + eprintln!("Node Range: {:?}", range); + eprintln!("Position: {:?}", position); + eprintln!("Position Index: {:?}", index); + eprintln!("Source: {:?}", source); if range.index <= index && range.index + range.length >= index { return true; } @@ -87,8 +86,47 @@ pub fn get_location(node: &InteractableNode, file: &SolidityAstFile) -> Location #[cfg(test)] mod test { + use std::str::FromStr; + + use crate::{test_utils::create_test_ast_file, utils::index_to_position}; + pub use super::*; + #[test] + fn test_get_location() { + let file = create_test_ast_file(); + let node = file.ast.nodes[0].clone(); + if let SourceUnitNodesItem::ContractDefinition(node) = node { + let node = InteractableNode::ContractDefinition(node); + let location = get_location(&node, &file); + let expected_location = Location { + start: Position { line: 3, column: 5 }, + end: Position { + line: 12, + column: 6, + }, + uri: "test.sol".to_string(), + }; + assert_eq!(location, expected_location); + } else { + panic!("Expected ContractDefinition"); + } + } + + #[test] + fn test_log_is_node_in_range() { + let file = create_test_ast_file(); + let node = file.ast.nodes[0].clone(); + if let SourceUnitNodesItem::ContractDefinition(contract) = node { + let position = Position { line: 3, column: 5 }; + let source = &file.file.content; + let is_in_range = log_is_node_in_range(&contract.src, &position, source); + assert_eq!(is_in_range, true); + } else { + panic!("Expected ContractDefinition"); + } + } + #[test] fn postion_to_index_when_position_not_matched() { let source = "pragma solidity ^0.8.0; @@ -122,4 +160,265 @@ contract Counter { let expected_idx = 240; assert_eq!(index, expected_idx); } + + #[test] + fn postion_to_index_when_position_matched() { + let source = "pragma solidity ^0.8.0; + +contract Counter { + uint256 public number; + uint256 public x = 2; + uint256 public y = x; + + function setNumber(uint256 newNumber) public + { + tx.origin; + number = newNumber + y; + d + } + + function increment() public { + setNumber(number + 1); + } + + function notUsed() internal { + uint256 x = 1; + number; + } +}"; + let position = Position { + line: 12, + column: 10, + }; + let index = position_to_index(&position, source); + let expected_idx = 240; + assert_eq!(index, expected_idx); + } + + #[test] + fn postion_to_index_start_of_file() { + let source = "pragma solidity ^0.8.0; + +contract Counter { + uint256 public number; + uint256 public x = 2; + uint256 public y = x; + + function setNumber(uint256 newNumber) public + { + tx.origin; + number = newNumber + y; + d + } + + function increment() public { + setNumber(number + 1); + } + + function notUsed() internal { + uint256 x = 1; + number; + } +}"; + let position = Position { line: 0, column: 1 }; + let index = position_to_index(&position, source); + let expected_idx = 0; + assert_eq!(index, expected_idx); + } + + #[test] + fn test_index_to_position() { + let source = "pragma solidity ^0.8.0; + +contract Counter { + uint256 public number; + uint256 public x = 2; + uint256 public y = x; + + function setNumber(uint256 newNumber) public + { + tx.origin; + number = newNumber + y; + d + } + + function increment() public { + setNumber(number + 1); + } + + function notUsed() internal { + uint256 x = 1; + number; + } +}"; + let index = 240; + let position = index_to_position(index, source); + let expected_position = Position { + line: 12, + column: 10, + }; + assert_eq!(position, expected_position); + } + + #[test] + fn test_source_location_to_range() { + let location = "240:1"; + let range = source_location_to_range(location); + let expected_range = Range { + index: 240, + length: 1, + }; + assert_eq!(range, expected_range); + } + + #[test] + fn test_is_node_in_range() { + let source = "pragma solidity ^0.8.0; + +contract Counter { + uint256 public number; + uint256 public x = 2; + uint256 public y = x; + + function setNumber(uint256 newNumber) public + { + tx.origin; + number = newNumber + y; + d + } + + function increment() public { + setNumber(number + 1); + } + + function notUsed() internal { + uint256 x = 1; + number; + } +}"; + let position = Position { + line: 12, + column: 10, + }; + let location = "240:1:1"; + let src_location = SourceLocation::from_str(location); + let is_in_range = is_node_in_range(&src_location.unwrap(), &position, source); + assert_eq!(is_in_range, true); + } + + #[test] + fn test_is_node_in_range_when_not_in_range() { + let source = "pragma solidity ^0.8.0; + +contract Counter { + uint256 public number; + uint256 public x = 2; + uint256 public y = x; + + function setNumber(uint256 newNumber) public + { + tx.origin; + number = newNumber + y; + d + } + + function increment() public { + setNumber(number + 1); + } + + function notUsed() internal { + uint256 x = 1; + number; + } +}"; + + let position = Position { + line: 12, + column: 10, + }; + let location = "210:1:1"; + let src_location = SourceLocation::from_str(location); + let is_in_range = is_node_in_range(&src_location.unwrap(), &position, source); + assert_eq!(is_in_range, false); + } + + #[test] + fn test_is_node_in_range_when_not_in_range_with_empty_source() { + let source = ""; + let position = Position { + line: 12, + column: 10, + }; + let location = "210:1:1"; + let src_location = SourceLocation::from_str(location); + let is_in_range = is_node_in_range(&src_location.unwrap(), &position, source); + assert_eq!(is_in_range, false); + } + + #[test] + fn test_is_node_in_range_when_not_in_range_with_empty_location() { + let source = "pragma solidity ^0.8.0; + +contract Counter { + uint256 public number; + uint256 public x = 2; + uint256 public y = x; + + function setNumber(uint256 newNumber) public + { + tx.origin; + number = newNumber + y; + d + } + + function increment() public { + setNumber(number + 1); + } + + function notUsed() internal { + uint256 x = 1; + number; + } +}"; + let position = Position { + line: 12, + column: 10, + }; + let location = "0:0:0"; + let src_location = SourceLocation::from_str(location); + let is_in_range = is_node_in_range(&src_location.unwrap(), &position, source); + assert_eq!(is_in_range, false); + } + + #[test] + fn test_is_node_in_range_when_not_in_range_with_empty_position() { + let source = "pragma solidity ^0.8.0; + +contract Counter { + uint256 public number; + uint256 public x = 2; + uint256 public y = x; + + function setNumber(uint256 newNumber) public + { + tx.origin; + number = newNumber + y; + d + } + + function increment() public { + setNumber(number + 1); + } + + function notUsed() internal { + uint256 x = 1; + number; + } +}"; + let position = Position { line: 0, column: 0 }; + let location = "240:1:1"; + let src_location = SourceLocation::from_str(location); + let is_in_range = is_node_in_range(&src_location.unwrap(), &position, source); + assert_eq!(is_in_range, false); + } } diff --git a/libs/path-utils/src/lib.rs b/libs/path-utils/src/lib.rs index d2fb7aa..7e9feb7 100644 --- a/libs/path-utils/src/lib.rs +++ b/libs/path-utils/src/lib.rs @@ -52,3 +52,54 @@ pub fn escape_path(path: &str) -> String { pub fn escape_path(path: &str) -> String { path.to_string() } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_normalize_path() { + let path = "/c:/Users/username/Documents"; + assert_eq!(normalize_path(path), "c:/Users/username/Documents"); + } + + #[test] + fn test_normalize_path_windows() { + let path = "/c%3A/Users/username/Documents"; + assert_eq!(normalize_path(path), "c:/Users/username/Documents"); + } + + #[test] + fn test_join_path() { + let base_path = "C:/Users/username/Documents"; + let file = "file.sol"; + assert_eq!( + join_path(base_path, file), + "C:/Users/username/Documents/file.sol" + ); + } + + #[test] + fn test_slashify_path() { + let path = "C:\\Users\\username\\Documents"; + assert_eq!(slashify_path(path), "C:/Users/username/Documents"); + } + + #[test] + fn test_slashify_path_double_slash() { + let path = "C:\\Users\\\\username\\Documents"; + assert_eq!(slashify_path(path), "C:/Users/username/Documents"); + } + + #[test] + fn test_escape_path() { + let path = "c://Users/username/Documents"; + assert_eq!(escape_path(path), "/c%3A/Users/username/Documents"); + } + + #[test] + fn test_escape_path_windows() { + let path = "c://Users/username/Documents"; + assert_eq!(escape_path(path), "/c%3A/Users/username/Documents"); + } +} diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 05c2a90..36ee456 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -2536,9 +2536,6 @@ packages: string_decoder@1.1.1: resolution: {integrity: sha512-n/ShnvDi6FHbbVfviro+WojiFzv+s8MPMHBczVePfUpDJLwoLT0ht1l4YwBCbi8pJAveEEdnkHyPyTP/mzRfwg==} - string_decoder@1.3.0: - resolution: {integrity: sha512-hkRX8U1WjJFd8LsDJ2yQ/wWWxaopEsABU1XfkM8A+j0+85JAGppt16cr1Whg6KIbb4okU6Mql6BOj+uup/wKeA==} - strip-ansi@6.0.1: resolution: {integrity: sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==} engines: {node: '>=8'} @@ -5155,7 +5152,7 @@ snapshots: readable-stream@3.6.2: dependencies: inherits: 2.0.4 - string_decoder: 1.3.0 + string_decoder: 1.1.1 util-deprecate: 1.0.2 readdirp@3.6.0: @@ -5330,10 +5327,6 @@ snapshots: dependencies: safe-buffer: 5.1.2 - string_decoder@1.3.0: - dependencies: - safe-buffer: 5.2.1 - strip-ansi@6.0.1: dependencies: ansi-regex: 5.0.1 diff --git a/remove-me-8e99d6014e4246b5b7f0.txt b/remove-me-8e99d6014e4246b5b7f0.txt new file mode 100644 index 0000000..c368505 --- /dev/null +++ b/remove-me-8e99d6014e4246b5b7f0.txt @@ -0,0 +1 @@ +8e99d6014e4246b5b7f0 diff --git a/remove-me-d246d5e2e5664fa2a574.txt b/remove-me-d246d5e2e5664fa2a574.txt new file mode 100644 index 0000000..99df92a --- /dev/null +++ b/remove-me-d246d5e2e5664fa2a574.txt @@ -0,0 +1 @@ +d246d5e2e5664fa2a574 diff --git a/servers/slither-server/src/slither.rs b/servers/slither-server/src/slither.rs index a271398..00eadff 100644 --- a/servers/slither-server/src/slither.rs +++ b/servers/slither-server/src/slither.rs @@ -66,8 +66,7 @@ async fn get_slither_error(uri: &str, workspace: &str) -> Result<(), SlitherErro output.wait().await?; errbuffer.read_to_string(&mut errdst).await?; - if !errdst.is_empty() - && errdst.contains("Error: Source file requires different compiler version") + if errdst.len() > 0 && errdst.contains("Error: Source file requires different compiler version") { let regex = Regex::new(r"(?m)(?:current compiler is.+\))").unwrap(); let match_ = regex.find(&errdst).unwrap().as_str(); @@ -76,7 +75,7 @@ async fn get_slither_error(uri: &str, workspace: &str) -> Result<(), SlitherErro "Slither needs a different version from the one specified in file: {}", match_ ))); - } else if !errdst.is_empty() && errdst.contains("Invalid option for --evm-version:") { + } else if errdst.len() > 0 && errdst.contains("Invalid option for --evm-version:") { return Err(SlitherError::Unknown("Please explicitly specify the evm version in the foundry.toml file to a compatible version of your solc compiler version".to_string())); } Ok(()) diff --git a/vscode/src/code-actions.ts b/vscode/src/code-actions.ts index 0d32abe..ca54f64 100644 --- a/vscode/src/code-actions.ts +++ b/vscode/src/code-actions.ts @@ -1,6 +1,6 @@ -import * as path from "path"; -import * as os from "os"; -import { workspace, ExtensionContext, Uri } from "vscode"; +import * as path from 'path'; +import * as os from 'os'; +import { workspace, ExtensionContext, Uri } from 'vscode'; import { LanguageClient, LanguageClientOptions, @@ -8,13 +8,11 @@ import { TransportKind, SocketTransport, StreamInfo, -} from "vscode-languageclient/node"; -import { TextDecoder } from "util"; -import * as net from "net"; +} from 'vscode-languageclient/node'; +import { TextDecoder } from 'util'; +import * as net from 'net'; -export async function createCodeActionsClient( - context: ExtensionContext, -): Promise { +export async function createCodeActionsClient(context: ExtensionContext): Promise { /* let connectionInfo = { port: 9001, @@ -39,12 +37,7 @@ export async function createCodeActionsClient( // The server is implemented in node const serverBinary = context.asAbsolutePath( - path.join( - "dist", - os.platform().startsWith("win") - ? "code-actions-server.exe" - : "code-actions-server", - ), + path.join('dist', os.platform().startsWith('win') ? 'code-actions-server.exe' : 'code-actions-server'), ); const serverOptions: ServerOptions = { @@ -58,17 +51,17 @@ export async function createCodeActionsClient( // Options to control the language client const clientOptions: LanguageClientOptions = { // Register the server for plain text documents - documentSelector: [{ scheme: "file", language: "solidity" }], + documentSelector: [{ scheme: 'file', language: 'solidity' }], synchronize: { // Notify the server about file changes to '.clientrc files contained in the workspace - fileEvents: workspace.createFileSystemWatcher("**/.solidhunter.json"), + fileEvents: workspace.createFileSystemWatcher('**/.solidhunter.json'), }, }; // Create the language client and start the client. const client = new LanguageClient( - "osmium-solidity-code-actions", - "Osmium Solidity Code Actions Language Server", + 'osmium-solidity-code-actions', + 'Osmium Solidity Code Actions Language Server', serverOptions, clientOptions, ); diff --git a/vscode/src/fmt-wrapper.ts b/vscode/src/fmt-wrapper.ts index 6ecf457..50892d1 100644 --- a/vscode/src/fmt-wrapper.ts +++ b/vscode/src/fmt-wrapper.ts @@ -154,7 +154,9 @@ function registerForgeFmtLinter(context: vscode.ExtensionContext): { forgeFmt(args) .then((result) => { - if (result.exitCode !== 0) { + if (result.exitCode === 0) { + vscode.window.showInformationMessage('Forge fmt ran successfully.'); + } else { vscode.window.showErrorMessage('Forge fmt failed. Please check the output for details.'); console.log(result.output); @@ -167,10 +169,10 @@ function registerForgeFmtLinter(context: vscode.ExtensionContext): { }); const formatter = vscode.languages.registerDocumentFormattingEditProvider('solidity', { - provideDocumentFormattingEdits: async (document) => { + provideDocumentFormattingEdits: (document) => { if (!isFmtInstalled()) { vscode.window.showErrorMessage('Forge fmt is not installed. Please install it and try again.'); - return []; + return; } const options: ForgeFmtOptions = { @@ -184,20 +186,17 @@ function registerForgeFmtLinter(context: vscode.ExtensionContext): { files: [document.fileName], }; - try { - await forgeFmt(args); - - // Read the formatted file - const formattedText = await vscode.workspace.fs.readFile(vscode.Uri.file(document.fileName)); - const fullRange = new vscode.Range(document.positionAt(0), document.positionAt(document.getText().length)); + return forgeFmt(args).then((result) => { + if (result.exitCode === 0) { + vscode.window.showInformationMessage('Forge fmt ran successfully.'); + } else { + vscode.window.showErrorMessage('Forge fmt failed. Please check the output for details.'); - return [vscode.TextEdit.replace(fullRange, Buffer.from(formattedText).toString('utf8'))]; - } catch (error) { - vscode.window.showErrorMessage('Forge fmt failed. Please check the output for details.'); - console.error(error); - } + console.log(result.output); + } - return []; + return []; + }); }, }); diff --git a/vscode/src/foundry-compiler.ts b/vscode/src/foundry-compiler.ts index b152144..b67c490 100644 --- a/vscode/src/foundry-compiler.ts +++ b/vscode/src/foundry-compiler.ts @@ -1,24 +1,12 @@ -import * as path from "path"; -import { workspace, ExtensionContext } from "vscode"; -import { - LanguageClient, - LanguageClientOptions, - ServerOptions, - TransportKind, -} from "vscode-languageclient/node"; -import * as os from "os"; +import * as path from 'path'; +import { workspace, ExtensionContext } from 'vscode'; +import { LanguageClient, LanguageClientOptions, ServerOptions, TransportKind } from 'vscode-languageclient/node'; +import * as os from 'os'; -export function createFoundryCompilerClient( - context: ExtensionContext, -): LanguageClient { +export function createFoundryCompilerClient(context: ExtensionContext): LanguageClient { // The server is implemented in node const serverBinary = context.asAbsolutePath( - path.join( - "dist", - os.platform().startsWith("win") - ? "foundry-compiler-server.exe" - : "foundry-compiler-server", - ), + path.join('dist', os.platform().startsWith('win') ? 'foundry-compiler-server.exe' : 'foundry-compiler-server'), ); // If the extension is launched in debug mode then the debug server options are used @@ -34,17 +22,17 @@ export function createFoundryCompilerClient( // Options to control the language client const clientOptions: LanguageClientOptions = { // Register the server for plain text documents - documentSelector: [{ scheme: "file", language: "solidity" }], + documentSelector: [{ scheme: 'file', language: 'solidity' }], synchronize: { // Notify the server about file changes to '.clientrc files contained in the workspace - fileEvents: workspace.createFileSystemWatcher("**/.solidhunter.json"), + fileEvents: workspace.createFileSystemWatcher('**/.solidhunter.json'), }, }; // Create the language client and start the client. const client = new LanguageClient( - "osmium-solidity-foundry-compiler", - "Osmium Solidity Foundry Compiler Language Server", + 'osmium-solidity-foundry-compiler', + 'Osmium Solidity Foundry Compiler Language Server', serverOptions, clientOptions, ); diff --git a/vscode/src/gas-estimation.ts b/vscode/src/gas-estimation.ts index 06db1f3..0ed850b 100644 --- a/vscode/src/gas-estimation.ts +++ b/vscode/src/gas-estimation.ts @@ -1,6 +1,6 @@ -import { execSync, exec } from "child_process"; -import { Disposable} from "vscode"; -import * as vscode from "vscode"; +import { execSync, exec } from 'child_process'; +import { Disposable } from 'vscode'; +import * as vscode from 'vscode'; type GasReport = { average: bigint; @@ -21,7 +21,7 @@ type ReportDecorators = Map; function isForgeInstalled(): boolean { try { - execSync("forge --version"); + execSync('forge --version'); return true; } catch (error) { return false; @@ -43,71 +43,62 @@ async function gasReportTests(cwd: string): Promise { resolve(); } - // pqrse the forge test --gas-report output to find contracts and functions - let contractName = ""; - await Promise.all( - _stdout.split("\n").map(async (line: string) => { - const lineParts = line.split("|"); - if (lineParts.length === 8) { - const trimmedLineParts = lineParts.map((part) => part.trim()); - if ( - trimmedLineParts[1] !== "" && - trimmedLineParts[2] === "" && - trimmedLineParts[3] === "" && - trimmedLineParts[4] === "" && - trimmedLineParts[5] === "" && - trimmedLineParts[6] === "" - ) { - contractName = trimmedLineParts[1].split(" ")[0]; - } + // pqrse the forge test --gas-report output to find contracts and functions + let contractName = ''; + await Promise.all( + _stdout.split('\n').map(async (line: string) => { + const lineParts = line.split('|'); + if (lineParts.length === 8) { + const trimmedLineParts = lineParts.map((part) => part.trim()); + if ( + trimmedLineParts[1] !== '' && + trimmedLineParts[2] === '' && + trimmedLineParts[3] === '' && + trimmedLineParts[4] === '' && + trimmedLineParts[5] === '' && + trimmedLineParts[6] === '' + ) { + contractName = trimmedLineParts[1].split(' ')[0]; + } - if ( - trimmedLineParts[1] !== "" && - trimmedLineParts[2] !== "" && - trimmedLineParts[3] !== "" && - trimmedLineParts[4] !== "" && - trimmedLineParts[5] !== "" && - trimmedLineParts[6] !== "" && - !trimmedLineParts[1].split("").every((char) => char === "-") && - !trimmedLineParts[2].split("").every((char) => char === "-") && - !trimmedLineParts[3].split("").every((char) => char === "-") && - !trimmedLineParts[4].split("").every((char) => char === "-") && - !trimmedLineParts[5].split("").every((char) => char === "-") && - !trimmedLineParts[6].split("").every((char) => char === "-") && - trimmedLineParts[1] !== "Function Name" - ) { - const functionName = trimmedLineParts[1]; - const min = BigInt(trimmedLineParts[2]); - const average = BigInt(trimmedLineParts[3]); - const median = BigInt(trimmedLineParts[4]); - const max = BigInt(trimmedLineParts[5]); - - const splittedContractName = contractName.split(":"); - const totalPath = `${cwd}/${splittedContractName[0]}`; - if (!reports.has(totalPath)) { - reports.set(totalPath, new Map()); - } - if (reports.get(totalPath)?.has(splittedContractName[1])) { - reports - .get(totalPath) - ?.get(splittedContractName[1]) - ?.set(functionName, { min, average, median, max }); - } else { - reports - .get(totalPath) - ?.set(splittedContractName[1], new Map()); - reports - .get(totalPath) - ?.get(splittedContractName[1]) - ?.set(functionName, { min, average, median, max }); - } + if ( + trimmedLineParts[1] !== '' && + trimmedLineParts[2] !== '' && + trimmedLineParts[3] !== '' && + trimmedLineParts[4] !== '' && + trimmedLineParts[5] !== '' && + trimmedLineParts[6] !== '' && + !trimmedLineParts[1].split('').every((char) => char === '-') && + !trimmedLineParts[2].split('').every((char) => char === '-') && + !trimmedLineParts[3].split('').every((char) => char === '-') && + !trimmedLineParts[4].split('').every((char) => char === '-') && + !trimmedLineParts[5].split('').every((char) => char === '-') && + !trimmedLineParts[6].split('').every((char) => char === '-') && + trimmedLineParts[1] !== 'Function Name' + ) { + const functionName = trimmedLineParts[1]; + const min = BigInt(trimmedLineParts[2]); + const average = BigInt(trimmedLineParts[3]); + const median = BigInt(trimmedLineParts[4]); + const max = BigInt(trimmedLineParts[5]); + + const splittedContractName = contractName.split(':'); + const totalPath = `${cwd}/${splittedContractName[0]}`; + if (!reports.has(totalPath)) { + reports.set(totalPath, new Map()); + } + if (reports.get(totalPath)?.has(splittedContractName[1])) { + reports.get(totalPath)?.get(splittedContractName[1])?.set(functionName, { min, average, median, max }); + } else { + reports.get(totalPath)?.set(splittedContractName[1], new Map()); + reports.get(totalPath)?.get(splittedContractName[1])?.set(functionName, { min, average, median, max }); } } - }), - ); - resolve(); - }, - ), + } + }), + ); + resolve(); + }), ); // Go through the reports and create the decorations @@ -116,10 +107,7 @@ async function gasReportTests(cwd: string): Promise { const content = await vscode.workspace.fs.readFile(vscode.Uri.file(path)); for (const [contract, _] of report) { - const functionsInsideContract = getFunctionsInsideContract( - content.toString(), - contract, - ); + const functionsInsideContract = getFunctionsInsideContract(content.toString(), contract); for (const func of functionsInsideContract) { const gas = report.get(contract)?.get(func.name)?.average; if (!gas) { @@ -128,10 +116,7 @@ async function gasReportTests(cwd: string): Promise { let range = new vscode.Range( new vscode.Position(func.line - 1, 0), - new vscode.Position( - func.line - 1, - content.toString().split("\n")[func.line - 1].length, - ), + new vscode.Position(func.line - 1, content.toString().split('\n')[func.line - 1].length), ); let decoration = { range, @@ -156,59 +141,47 @@ async function getGasReport(contracts: string[], cwd: string): Promise { await Promise.all( contracts.map(async (contract) => { return await new Promise((resolve, reject) => { - exec( - `forge inspect "${contract}" gasEstimates`, - { cwd }, - (error: any, _stdout: any, _stderr: any) => { - if (error) { - console.log("error", error); - reject(error); - } + exec(`forge inspect "${contract}" gasEstimates`, { cwd }, (error: any, _stdout: any, _stderr: any) => { + if (error) { + console.log('error', error); + reject(error); + } if (_stdout === "null\n") { resolve(); } - const json = JSON.parse(_stdout); - const internalFunctions = Object.keys(json.internal); - const externalFunctions = Object.keys(json.external); - - // Go through the internal and external functions and create the report - internalFunctions.forEach((functionName) => { - const cleanFunctionName = functionName.split("(")[0]; - const res: string = json.internal[functionName]; - if (res !== "infinite") { - if (report.has(contract)) { - report - .get(contract) - ?.set(cleanFunctionName, { average: BigInt(res) }); - } else { - report.set(contract, new Map()); - report - .get(contract) - ?.set(cleanFunctionName, { average: BigInt(res) }); - } + const json = JSON.parse(_stdout); + const internalFunctions = Object.keys(json.internal); + const externalFunctions = Object.keys(json.external); + + // Go through the internal and external functions and create the report + internalFunctions.forEach((functionName) => { + const cleanFunctionName = functionName.split('(')[0]; + const res: string = json.internal[functionName]; + if (res !== 'infinite') { + if (report.has(contract)) { + report.get(contract)?.set(cleanFunctionName, { average: BigInt(res) }); + } else { + report.set(contract, new Map()); + report.get(contract)?.set(cleanFunctionName, { average: BigInt(res) }); } - }); - externalFunctions.forEach((functionName) => { - const cleanFunctionName = functionName.split("(")[0]; - const res: string = json.external[functionName]; - if (res !== "infinite") { - if (report.has(contract)) { - report - .get(contract) - ?.set(cleanFunctionName, { average: BigInt(res) }); - } else { - report.set(contract, new Map()); - report - .get(contract) - ?.set(cleanFunctionName, { average: BigInt(res) }); - } + } + }); + externalFunctions.forEach((functionName) => { + const cleanFunctionName = functionName.split('(')[0]; + const res: string = json.external[functionName]; + if (res !== 'infinite') { + if (report.has(contract)) { + report.get(contract)?.set(cleanFunctionName, { average: BigInt(res) }); + } else { + report.set(contract, new Map()); + report.get(contract)?.set(cleanFunctionName, { average: BigInt(res) }); } - }); - resolve(); - }, - ); + } + }); + resolve(); + }); }); }), ); @@ -218,24 +191,21 @@ async function getGasReport(contracts: string[], cwd: string): Promise { function getXthWord(line: string, index: number): string { return ( line - .split(" ") + .split(' ') .filter((str) => !!str.length) - .at(index) || "" + .at(index) || '' ); } function getContractsInsideFile(content: string, path: string): string[] { const contracts: string[] = []; - const lines = content.split("\n"); + const lines = content.split('\n'); lines.forEach((line) => { - if (getXthWord(line, 0) === "contract") { + if (getXthWord(line, 0) === 'contract') { const contractName = getXthWord(line, 1); contracts.push(`${path}:${contractName}`); - } else if ( - getXthWord(line, 0) === "abstract" && - getXthWord(line, 1) === "contract" - ) { + } else if (getXthWord(line, 0) === 'abstract' && getXthWord(line, 1) === 'contract') { const contractName = getXthWord(line, 2); contracts.push(`${path}:${contractName}`); } @@ -244,33 +214,29 @@ function getContractsInsideFile(content: string, path: string): string[] { } // Get all the functions and abstract functions inside a contract with their line number -function getFunctionsInsideContract( - content: string, - contractName: string, -): Function[] { +function getFunctionsInsideContract(content: string, contractName: string): Function[] { const functions: Function[] = []; - const lines = content.split("\n"); + const lines = content.split('\n'); let start = false; let bracketsCount = 0; - let currentContractName = ""; + let currentContractName = ''; lines.forEach((line, index) => { const firstWord = getXthWord(line, 0); const secondWord = getXthWord(line, 1); - if (firstWord === "contract") { + if (firstWord === 'contract') { currentContractName = secondWord; if (contractName === currentContractName) { start = true; } } if (start) { - bracketsCount += - line.split("{").length - 1 - (line.split("}").length - 1); + bracketsCount += line.split('{').length - 1 - (line.split('}').length - 1); if (bracketsCount === -1) { return functions; } - if (firstWord === "function") { - const functionName = secondWord.split("(")[0]; + if (firstWord === 'function') { + const functionName = secondWord.split('(')[0]; functions.push({ name: functionName, line: index + 1, @@ -283,10 +249,7 @@ function getFunctionsInsideContract( } // compute the decorations to send based on forge inspection -async function gasReport( - content: string, - path: string, -): Promise { +async function gasReport(content: string, path: string): Promise { const workspace = vscode.workspace.workspaceFolders?.[0]; const workspacePath = workspace?.uri.path; if (!workspacePath) { @@ -297,10 +260,7 @@ async function gasReport( const functionsPerContract: Map = new Map(); contracts.map((contract) => { - const functions = getFunctionsInsideContract( - content, - contract.split(":")[1], - ); + const functions = getFunctionsInsideContract(content, contract.split(':')[1]); functionsPerContract.set(contract, functions); }); @@ -314,10 +274,7 @@ async function gasReport( let range = new vscode.Range( new vscode.Position(func.line - 1, 0), - new vscode.Position( - func.line - 1, - content.split("\n")[func.line - 1].length, - ), + new vscode.Position(func.line - 1, content.split('\n')[func.line - 1].length), ); let decoration = { @@ -356,12 +313,18 @@ async function showReport( } } -export function registerGasEstimation(context: vscode.ExtensionContext): {openDisposable:Disposable, SaveDisposable:Disposable, visibleTextEditorsDisposable:Disposable, activeTextEditorDisposable:Disposable, commandDisposable:Disposable} { +export function registerGasEstimation(context: vscode.ExtensionContext): { + openDisposable: Disposable; + SaveDisposable: Disposable; + visibleTextEditorsDisposable: Disposable; + activeTextEditorDisposable: Disposable; + commandDisposable: Disposable; +} { const forgeInstalled = isForgeInstalled(); const decorationType = vscode.window.createTextEditorDecorationType({ after: { - color: "rgba(255, 255, 255, 0.5)", + color: 'rgba(255, 255, 255, 0.5)', }, }); @@ -375,12 +338,12 @@ export function registerGasEstimation(context: vscode.ExtensionContext): {openDi if (!workspacePath) { return; } - const cleanPath = document.uri.path.replace(workspacePath, ""); + const cleanPath = document.uri.path.replace(workspacePath, ''); if ( - cleanPath.includes("lib") || - cleanPath.includes("test") || - cleanPath.includes("script") || - cleanPath.includes(".git") || + cleanPath.includes('lib') || + cleanPath.includes('test') || + cleanPath.includes('script') || + cleanPath.includes('.git') || !forgeInstalled ) { return; @@ -399,12 +362,12 @@ export function registerGasEstimation(context: vscode.ExtensionContext): {openDi if (!workspacePath) { return; } - const cleanPath = document.uri.path.replace(workspacePath, ""); + const cleanPath = document.uri.path.replace(workspacePath, ''); if ( - cleanPath.includes("lib") || - cleanPath.includes("test") || - cleanPath.includes("script") || - cleanPath.includes(".git") || + cleanPath.includes('lib') || + cleanPath.includes('test') || + cleanPath.includes('script') || + cleanPath.includes('.git') || !forgeInstalled ) { return; @@ -430,11 +393,9 @@ export function registerGasEstimation(context: vscode.ExtensionContext): {openDi } }); - const onDidcommandDisposable = vscode.commands.registerCommand("osmium.gas-estimation", async function () { + const onDidcommandDisposable = vscode.commands.registerCommand('osmium.gas-estimation', async function () { if (vscode.workspace.workspaceFolders?.[0].uri.fsPath) { - const report = await gasReportTests( - vscode.workspace.workspaceFolders?.[0].uri.fsPath, - ); + const report = await gasReportTests(vscode.workspace.workspaceFolders?.[0].uri.fsPath); reportsSaved = report; } vscode.window.visibleTextEditors.forEach((editor) => { @@ -448,5 +409,11 @@ export function registerGasEstimation(context: vscode.ExtensionContext): {openDi context.subscriptions.push(onDidChangeActiveTextEditorDisposable); context.subscriptions.push(onDidcommandDisposable); - return {openDisposable:onDidOpenDisposable, SaveDisposable:onDidSaveDisposable, visibleTextEditorsDisposable:onDidChangeVisibleTextEditorsDisposable, activeTextEditorDisposable:onDidChangeActiveTextEditorDisposable, commandDisposable:onDidcommandDisposable}; + return { + openDisposable: onDidOpenDisposable, + SaveDisposable: onDidSaveDisposable, + visibleTextEditorsDisposable: onDidChangeVisibleTextEditorsDisposable, + activeTextEditorDisposable: onDidChangeActiveTextEditorDisposable, + commandDisposable: onDidcommandDisposable, + }; } diff --git a/vscode/src/linter.ts b/vscode/src/linter.ts index a8918cc..34532b2 100644 --- a/vscode/src/linter.ts +++ b/vscode/src/linter.ts @@ -1,23 +1,13 @@ -import * as path from "path"; -import * as os from "os"; -import { workspace, ExtensionContext, Uri } from "vscode"; -import { - LanguageClient, - LanguageClientOptions, - ServerOptions, - TransportKind, -} from "vscode-languageclient/node"; -import { TextDecoder } from "util"; +import * as path from 'path'; +import * as os from 'os'; +import { workspace, ExtensionContext, Uri } from 'vscode'; +import { LanguageClient, LanguageClientOptions, ServerOptions, TransportKind } from 'vscode-languageclient/node'; +import { TextDecoder } from 'util'; -export async function createLinterClient( - context: ExtensionContext, -): Promise { +export async function createLinterClient(context: ExtensionContext): Promise { // The server is implemented in node const serverBinary = context.asAbsolutePath( - path.join( - "dist", - os.platform().startsWith("win") ? "linter-server.exe" : "linter-server", - ), + path.join('dist', os.platform().startsWith('win') ? 'linter-server.exe' : 'linter-server'), ); // If the extension is launched in debug mode then the debug server options are used @@ -33,22 +23,22 @@ export async function createLinterClient( // Options to control the language client const clientOptions: LanguageClientOptions = { // Register the server for plain text documents - documentSelector: [{ scheme: "file", language: "solidity" }], + documentSelector: [{ scheme: 'file', language: 'solidity' }], synchronize: { // Notify the server about file changes to '.clientrc files contained in the workspace - fileEvents: workspace.createFileSystemWatcher("**/.solidhunter.json"), + fileEvents: workspace.createFileSystemWatcher('**/.solidhunter.json'), }, }; // Create the language client and start the client. const client = new LanguageClient( - "osmium-solidity-linter", - "Osmium Solidity Linter Language Server", + 'osmium-solidity-linter', + 'Osmium Solidity Linter Language Server', serverOptions, clientOptions, ); - client.onRequest("osmium/getContent", async (params: { uri: string }) => { + client.onRequest('osmium/getContent', async (params: { uri: string }) => { const contentUint8 = await workspace.fs.readFile(Uri.parse(params.uri)); const content = new TextDecoder().decode(contentUint8); return { diff --git a/vscode/src/sidebar-provider.ts b/vscode/src/sidebar-provider.ts index 3619b52..323fbb5 100644 --- a/vscode/src/sidebar-provider.ts +++ b/vscode/src/sidebar-provider.ts @@ -33,9 +33,7 @@ export class SidebarProvider implements vscode.WebviewViewProvider { ) {} async _osmiumWatcherCallback(uri: vscode.Uri) { - if (!this._view) { - return; - } + if (!this._view) {return;} const basename = path.basename(uri.fsPath, '.json'); if (basename === 'contracts') { this._interactContractRepository?.load(); @@ -173,7 +171,106 @@ export class SidebarProvider implements vscode.WebviewViewProvider { break; case MessageType.OPEN_PANEL: await vscode.commands.executeCommand('osmium.show-env-panel'); + /*case MessageType.EDIT_WALLETS: + const walletAction = await window.showQuickPick([InputAction.ADD, InputAction.REMOVE], { + title: 'Edit Wallets', + ignoreFocusOut: true, + }); + + if (walletAction === InputAction.ADD) { + const inputs = await this._showInputsBox({ + walletName: 'Enter name', + walletAddress: 'Enter address', + walletPk: 'Enter private key', + walletRpc: 'Enter rpc', + }); + if (!inputs) {return;} + if (!inputs.walletAddress.startsWith('0x') || !inputs.walletPk.startsWith('0x')) {return;} + if (!inputs.walletRpc.startsWith('http') && !inputs.walletRpc.startsWith('ws')) {return;} + + this._walletRepository.createWallet( + inputs.walletName, +
inputs.walletAddress, +
inputs.walletPk, + inputs.walletRpc, + ); + } + + if (walletAction === InputAction.REMOVE) { + const walletName = await window.showQuickPick( + this._walletRepository.getWallets().map((w) => w.name), + { + title: 'Remove wallet', + ignoreFocusOut: true, + }, + ); + if (!walletName) {return;} + this._walletRepository.deleteWallet(walletName); + } break; + case MessageType.EDIT_CONTRACTS: + const contractAction = await window.showQuickPick([InputAction.ADD, InputAction.REMOVE], { + title: 'Edit Wallets', + ignoreFocusOut: true, + }); + + if (contractAction === InputAction.ADD) { + const inputs = await this._showInputsBox({ + contractName: 'Enter name', + contractAddress: 'Enter address', + contractAbi: 'Enter abi', + contractRpc: 'Enter rpc', + contractChainId: 'Enter chain id', + }); + if (!inputs || !inputs.contractAddress.startsWith('0x')) {return;} + if (!inputs.contractRpc.startsWith('http') && !inputs.contractRpc.startsWith('ws')) {return;} + this._interactContractRepository.createContract( +
inputs['contractAddress'], + JSON.parse(inputs['contractAbi']), + parseInt(inputs['contractChainId']), + inputs['contractName'], + inputs['contractRpc'], + ); + } + if (contractAction === InputAction.REMOVE) { + const contractName = await window.showQuickPick( + this._interactContractRepository.getContracts().map((c) => c.name), + { + title: 'Remove contract', + ignoreFocusOut: true, + }, + ); + if (!contractName) {return;} + this._interactContractRepository.deleteContract(contractName); + } + break; + case MessageType.EDIT_ENVIRONMENT: + const environmentAction = await window.showQuickPick([InputAction.ADD, InputAction.REMOVE], { + title: 'Edit environment', + ignoreFocusOut: true, + }); + if (environmentAction === InputAction.ADD) { + const inputs = await this._showInputsBox({ + environmentName: 'Enter name', + environmentRpc: 'Enter rpc', + }); + if (!inputs) {return;} + if (!inputs.environmentRpc.startsWith('http') && !inputs.environmentRpc.startsWith('ws')) {return;} + + this._environmentRepository.createEnvironment(inputs.environmentName, inputs.environmentRpc); + } + if (environmentAction === InputAction.REMOVE) { + const environmentName = await window.showQuickPick( + this._environmentRepository.getEnvironments().map((e) => e.name), + { + title: 'Remove environment', + ignoreFocusOut: true, + }, + ); + if (!environmentName) {return;} + this._environmentRepository.deleteEnvironment(environmentName); + } + break;*/ case MessageType.DEPLOY_SCRIPT: const deployScriptResponse = await this._deploy.deployScript({ environmentId: message.data.environment, diff --git a/vscode/src/slither.ts b/vscode/src/slither.ts index 2f572cd..1929c27 100644 --- a/vscode/src/slither.ts +++ b/vscode/src/slither.ts @@ -1,23 +1,13 @@ -import * as path from "path"; -import * as os from "os"; -import { workspace, ExtensionContext, Uri } from "vscode"; -import { - LanguageClient, - LanguageClientOptions, - ServerOptions, - TransportKind, -} from "vscode-languageclient/node"; -import { TextDecoder } from "util"; +import * as path from 'path'; +import * as os from 'os'; +import { workspace, ExtensionContext, Uri } from 'vscode'; +import { LanguageClient, LanguageClientOptions, ServerOptions, TransportKind } from 'vscode-languageclient/node'; +import { TextDecoder } from 'util'; -export async function createSlitherClient( - context: ExtensionContext, -): Promise { +export async function createSlitherClient(context: ExtensionContext): Promise { // The server is implemented in node const serverBinary = context.asAbsolutePath( - path.join( - "dist", - os.platform().startsWith("win") ? "slither-server.exe" : "slither-server", - ), + path.join('dist', os.platform().startsWith('win') ? 'slither-server.exe' : 'slither-server'), ); // If the extension is launched in debug mode then the debug server options are used @@ -33,22 +23,17 @@ export async function createSlitherClient( // Options to control the language client const clientOptions: LanguageClientOptions = { // Register the server for plain text documents - documentSelector: [{ scheme: "file", language: "solidity" }], + documentSelector: [{ scheme: 'file', language: 'solidity' }], synchronize: { // Notify the server about file changes to '.clientrc files contained in the workspace - fileEvents: workspace.createFileSystemWatcher("**/.solidhunter.json"), + fileEvents: workspace.createFileSystemWatcher('**/.solidhunter.json'), }, }; // Create the language client and start the client. - const client = new LanguageClient( - "osmium-slither", - "Osmium Slither Language Server", - serverOptions, - clientOptions, - ); + const client = new LanguageClient('osmium-slither', 'Osmium Slither Language Server', serverOptions, clientOptions); - client.onRequest("osmium/getContent", async (params: { uri: string }) => { + client.onRequest('osmium/getContent', async (params: { uri: string }) => { const contentUint8 = await workspace.fs.readFile(Uri.parse(params.uri)); const content = new TextDecoder().decode(contentUint8); return { diff --git a/vscode/src/tests-positions.ts b/vscode/src/tests-positions.ts index 378867c..c5c421c 100644 --- a/vscode/src/tests-positions.ts +++ b/vscode/src/tests-positions.ts @@ -1,25 +1,13 @@ -import * as path from "path"; -import * as os from "os"; -import { workspace, ExtensionContext, Uri } from "vscode"; -import { - LanguageClient, - LanguageClientOptions, - ServerOptions, - TransportKind, -} from "vscode-languageclient/node"; -import { TextDecoder } from "util"; +import * as path from 'path'; +import * as os from 'os'; +import { workspace, ExtensionContext, Uri } from 'vscode'; +import { LanguageClient, LanguageClientOptions, ServerOptions, TransportKind } from 'vscode-languageclient/node'; +import { TextDecoder } from 'util'; -export async function createTestsPositionsClient( - context: ExtensionContext, -): Promise { +export async function createTestsPositionsClient(context: ExtensionContext): Promise { // The server is implemented in node const serverBinary = context.asAbsolutePath( - path.join( - "dist", - os.platform().startsWith("win") - ? "tests-positions-server.exe" - : "tests-positions-server", - ), + path.join('dist', os.platform().startsWith('win') ? 'tests-positions-server.exe' : 'tests-positions-server'), ); // If the extension is launched in debug mode then the debug server options are used @@ -44,8 +32,8 @@ export async function createTestsPositionsClient( // Create the language client and start the client. const client = new LanguageClient( - "osmium-tests-positions", - "Osmium Solidity Tests Positions Language Server", + 'osmium-tests-positions', + 'Osmium Solidity Tests Positions Language Server', serverOptions, clientOptions, );