Skip to content

Commit

Permalink
CallSiteValue::get_called_fn_value: return None on indirect calls
Browse files Browse the repository at this point in the history
  • Loading branch information
airwoodix committed Feb 21, 2025
1 parent 97128a9 commit a4229d5
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 11 deletions.
71 changes: 65 additions & 6 deletions src/values/call_site_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@ use std::fmt::{self, Display};
use either::Either;
use llvm_sys::core::{
LLVMGetInstructionCallConv, LLVMGetTypeKind, LLVMIsTailCall, LLVMSetInstrParamAlignment,
LLVMSetInstructionCallConv, LLVMSetTailCall, LLVMTypeOf,
LLVMSetInstructionCallConv, LLVMSetTailCall, LLVMTypeOf, LLVMGetCalledValue
};
#[llvm_versions(8..)]
use llvm_sys::core::LLVMGetCalledFunctionType;
#[llvm_versions(18..)]
use llvm_sys::core::{LLVMGetTailCallKind, LLVMSetTailCallKind};
use llvm_sys::prelude::LLVMValueRef;
use llvm_sys::LLVMTypeKind;

use crate::attributes::{Attribute, AttributeLoc};
use crate::types::FunctionType;
#[llvm_versions(18..)]
use crate::values::operand_bundle::OperandBundleIter;
use crate::values::{AsValueRef, BasicValueEnum, FunctionValue, InstructionValue, Value};
Expand Down Expand Up @@ -201,9 +204,12 @@ impl<'ctx> CallSiteValue<'ctx> {

/// Gets the `FunctionValue` this `CallSiteValue` is based on.
///
/// Returns [`None`] if the call this value bases on is indirect or the retrieved function
/// value doesn't have the same type as the underlying call instruction.
///
/// # Example
///
/// ```no_run
/// ```
/// use inkwell::context::Context;
///
/// let context = Context::create();
Expand All @@ -220,12 +226,65 @@ impl<'ctx> CallSiteValue<'ctx> {
///
/// let call_site_value = builder.build_call(fn_value, &[], "my_fn").unwrap();
///
/// assert_eq!(call_site_value.get_called_fn_value(), fn_value);
/// assert_eq!(call_site_value.get_called_fn_value(), Some(fn_value));
/// ```
pub fn get_called_fn_value(self) -> FunctionValue<'ctx> {
use llvm_sys::core::LLVMGetCalledValue;
pub fn get_called_fn_value(self) -> Option<FunctionValue<'ctx>> {
// SAFETY: the passed LLVMValueRef is of type CallSite
let called_value = unsafe { LLVMGetCalledValue(self.as_value_ref()) };

let fn_value = unsafe { FunctionValue::new(called_value) };

// Check that the retrieved function value has the same type as the callee.
// This matches the behavior of the C++ API `CallBase::getCalledFunction`.
// This is only possible on LLVM >=8, where the `LLVMGetCalledFunctionType` API exists.
self.get_called_fn_value_check_type_consistency(fn_value)
}

#[llvm_versions(..8)]
#[inline]
fn get_called_fn_value_check_type_consistency(&self, fn_value: Option<FunctionValue<'ctx>>) -> Option<FunctionValue<'ctx>> {
fn_value
}

#[llvm_versions(8..)]
#[inline]
fn get_called_fn_value_check_type_consistency(&self, fn_value: Option<FunctionValue<'ctx>>) -> Option<FunctionValue<'ctx>> {
match fn_value {
Some(fn_value) if fn_value.get_type() == self.get_called_fn_type() => Some(fn_value),
_ => None,
}
}

unsafe { FunctionValue::new(LLVMGetCalledValue(self.as_value_ref())).expect("This should never be null?") }
/// Gets the type of the function called by the instruction this `CallSiteValue` is based on.
///
/// # Example
///
/// ```
/// use inkwell::context::Context;
///
/// let context = Context::create();
/// let builder = context.create_builder();
/// let module = context.create_module("my_mod");
/// let i32_type = context.i32_type();
/// let fn_type = i32_type.fn_type(&[], false);
/// let fn_value = module.add_function("my_fn", fn_type, None);
///
/// let entry_bb = context.append_basic_block(fn_value, "entry");
/// builder.position_at_end(entry_bb);
///
/// // Recursive call.
/// let call_site_value = builder.build_call(fn_value, &[], "my_fn").unwrap();
///
/// assert_eq!(call_site_value.get_called_fn_type(), fn_type);
/// ```
#[llvm_versions(8..)]
pub fn get_called_fn_type(self) -> FunctionType<'ctx> {
// SAFETY: the passed LLVMValueRef is of type CallSite
let fn_type_ref = unsafe {LLVMGetCalledFunctionType(self.as_value_ref())};

// FIXME?: this assumes that fn_type_ref is not null.
// SAFETY: fn_type_ref is a function type reference.
unsafe { FunctionType::new(fn_type_ref) }
}

/// Counts the number of `Attribute`s on this `CallSiteValue` at an index.
Expand Down
4 changes: 1 addition & 3 deletions src/values/fn_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,10 @@ impl<'ctx> FunctionValue<'ctx> {
///
/// The ref must be valid and of type function.
pub unsafe fn new(value: LLVMValueRef) -> Option<Self> {
if value.is_null() {
if value.is_null() || LLVMIsAFunction(value).is_null() {
return None;
}

assert!(!LLVMIsAFunction(value).is_null());

Some(FunctionValue {
fn_value: Value::new(value),
})
Expand Down
2 changes: 1 addition & 1 deletion tests/all/test_attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ fn test_attributes_on_call_site_values() {
assert!(call_site_value
.get_string_attribute(AttributeLoc::Return, "my_key")
.is_none());
assert_eq!(call_site_value.get_called_fn_value(), fn_value);
assert_eq!(call_site_value.get_called_fn_value(), Some(fn_value));

call_site_value.set_alignment_attribute(AttributeLoc::Return, 16);

Expand Down
45 changes: 44 additions & 1 deletion tests/all/test_values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ use inkwell::attributes::AttributeLoc;
#[llvm_versions(7..)]
use inkwell::comdat::ComdatSelectionKind;
use inkwell::context::Context;
use inkwell::memory_buffer::MemoryBuffer;
use inkwell::module::Linkage::*;
use inkwell::types::{AnyTypeEnum, StringRadix, VectorType};
#[llvm_versions(18..)]
use inkwell::values::OperandBundle;
use inkwell::values::{AnyValue, InstructionOpcode::*, FIRST_CUSTOM_METADATA_KIND_ID};
use inkwell::values::{AnyValue, CallSiteValue, InstructionOpcode::*, FIRST_CUSTOM_METADATA_KIND_ID};
use inkwell::{AddressSpace, DLLStorageClass, GlobalVisibility, ThreadLocalMode};

#[llvm_versions(18..)]
Expand Down Expand Up @@ -147,6 +148,48 @@ fn test_call_site_operand_bundles() {
args_iter.for_each(|arg| assert!(arg.into_int_value().is_const()));
}

/// Check that `CallSiteValue::get_called_fn_value` returns `None` if the underlying call is indirect.
/// Regression test for inkwell#571.
/// Retricted to LLVM >= 15, since the input IR uses opaque pointers.
#[llvm_versions(15..)]
#[test]
fn test_call_site_function_value_indirect_call() {
// ```c
// void dummy_fn();
//
// void my_fn() {
// void (*fn_ptr)(void) = &dummy_fn;
// (*fn_ptr)();
// }
// ```

let llvm_ir = r#"
source_filename = "my_mod";
define void @my_fn() {
entry:
%0 = alloca ptr, align 8
store ptr @dummy_fn, ptr %0, align 8
%1 = load ptr, ptr %0, align 8
call void %1()
ret void
}
declare void @dummy_fn();
"#;

let memory_buffer = MemoryBuffer::create_from_memory_range_copy(llvm_ir.as_bytes(), "my_mod");
let context = Context::create();
let module = context.create_module_from_ir(memory_buffer).unwrap();

let main_fn = module.get_function("my_fn").unwrap();
let inst = main_fn.get_last_basic_block().unwrap().get_instructions().nth(3).unwrap();
let call_site_value = CallSiteValue::try_from(inst).unwrap();

let fn_value = call_site_value.get_called_fn_value();
assert!(fn_value.is_none());
}

#[test]
fn test_set_get_name() {
let context = Context::create();
Expand Down

0 comments on commit a4229d5

Please sign in to comment.