diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index d320517100..5f2e369e20 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -595,6 +595,7 @@ RUN(NAME test_set_discard LABELS cpython llvm llvm_jit) RUN(NAME test_set_from_list LABELS cpython llvm llvm_jit) RUN(NAME test_set_clear LABELS cpython llvm) RUN(NAME test_set_pop LABELS cpython llvm) +RUN(NAME test_bytes_01 LABELS cpython llvm llvm_jit) RUN(NAME test_global_set LABELS cpython llvm llvm_jit) RUN(NAME test_for_loop LABELS cpython llvm llvm_jit c) RUN(NAME modules_01 LABELS cpython llvm llvm_jit c wasm wasm_x86 wasm_x64) diff --git a/integration_tests/test_bytes_01.py b/integration_tests/test_bytes_01.py new file mode 100644 index 0000000000..6b5e2b9537 --- /dev/null +++ b/integration_tests/test_bytes_01.py @@ -0,0 +1,20 @@ +def f(): + a: bytes = b"This is a test string" + b: bytes = b"This is another test string" + c: bytes = b"""Bigger test string with docstrings + Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do + eiusmod tempor incididunt ut labore et dolore magna aliqua. """ + + +def g(a: bytes) -> bytes: + return a + + +def h() -> bytes: + bar: bytes + bar = g(b"fiwabcd") + return b"12jw19\\xq0" + + +f() +h() diff --git a/src/libasr/ASR.asdl b/src/libasr/ASR.asdl index 5ea9a482b5..9b8da38482 100644 --- a/src/libasr/ASR.asdl +++ b/src/libasr/ASR.asdl @@ -140,6 +140,7 @@ expr | StringOrd(expr arg, ttype type, expr? value) | StringChr(expr arg, ttype type, expr? value) | StringFormat(expr fmt, expr* args, string_format_kind kind, ttype type, expr? value) + | BytesConstant(string s, ttype type) | CPtrCompare(expr left, cmpop op, expr right, ttype type, expr? value) | SymbolicCompare(expr left, cmpop op, expr right, ttype type, expr? value) | DictConstant(expr* keys, expr* values, ttype type) @@ -198,6 +199,7 @@ ttype | Real(int kind) | Complex(int kind) | Character(int kind, int len, expr? len_expr) + | Byte(int kind, int len, expr? len_expr) | Logical(int kind) | Set(ttype type) | List(ttype type) diff --git a/src/libasr/asdl_cpp.py b/src/libasr/asdl_cpp.py index 9463cb9d1f..05dcde949f 100644 --- a/src/libasr/asdl_cpp.py +++ b/src/libasr/asdl_cpp.py @@ -2,8 +2,9 @@ Generate C++ AST node definitions from an ASDL description. """ -import sys import os +import sys + import asdl diff --git a/src/libasr/asr_utils.h b/src/libasr/asr_utils.h index 42d64c4c1c..dae63f6833 100644 --- a/src/libasr/asr_utils.h +++ b/src/libasr/asr_utils.h @@ -207,6 +207,9 @@ static inline int extract_kind_from_ttype_t(const ASR::ttype_t* type) { case ASR::ttypeType::Character: { return ASR::down_cast(type)->m_kind; } + case ASR::ttypeType::Byte: { + return ASR::down_cast(type)->m_kind; + } case ASR::ttypeType::Logical: { return ASR::down_cast(type)->m_kind; } @@ -251,6 +254,10 @@ static inline void set_kind_to_ttype_t(ASR::ttype_t* type, int kind) { ASR::down_cast(type)->m_kind = kind; break; } + case ASR::ttypeType::Byte: { + ASR::down_cast(type)->m_kind = kind; + break; + } case ASR::ttypeType::Logical: { ASR::down_cast(type)->m_kind = kind; break; @@ -542,6 +549,9 @@ static inline std::string type_to_str(const ASR::ttype_t *t) case ASR::ttypeType::Character: { return "character"; } + case ASR::ttypeType::Byte: { + return "byte"; + } case ASR::ttypeType::Tuple: { return "tuple"; } @@ -990,7 +1000,8 @@ static inline bool is_value_constant(ASR::expr_t *a_value) { case ASR::exprType::ImpliedDoLoop: case ASR::exprType::PointerNullConstant: case ASR::exprType::ArrayConstant: - case ASR::exprType::StringConstant: { + case ASR::exprType::StringConstant: + case ASR::exprType::BytesConstant: { return true; } case ASR::exprType::RealBinOp: @@ -1421,6 +1432,9 @@ static inline std::string get_type_code(const ASR::ttype_t *t, bool use_undersco case ASR::ttypeType::Character: { return "str"; } + case ASR::ttypeType::Byte: { + return "bytes"; + } case ASR::ttypeType::Tuple: { ASR::Tuple_t *tup = ASR::down_cast(t); std::string result = "tuple"; @@ -1608,6 +1622,9 @@ static inline std::string type_to_str_python(const ASR::ttype_t *t, case ASR::ttypeType::Character: { return "str"; } + case ASR::ttypeType::Byte: { + return "bytes"; + } case ASR::ttypeType::Tuple: { ASR::Tuple_t *tup = ASR::down_cast(t); std::string result = "tuple["; @@ -2148,6 +2165,7 @@ inline size_t extract_dimensions_from_ttype(ASR::ttype_t *x, case ASR::ttypeType::Real: case ASR::ttypeType::Complex: case ASR::ttypeType::Character: + case ASR::ttypeType::Byte: case ASR::ttypeType::Logical: case ASR::ttypeType::StructType: case ASR::ttypeType::Enum: @@ -2419,6 +2437,7 @@ inline bool ttype_set_dimensions(ASR::ttype_t** x, case ASR::ttypeType::Real: case ASR::ttypeType::Complex: case ASR::ttypeType::Character: + case ASR::ttypeType::Byte: case ASR::ttypeType::Logical: case ASR::ttypeType::StructType: case ASR::ttypeType::Enum: @@ -2540,6 +2559,12 @@ static inline ASR::ttype_t* duplicate_type(Allocator& al, const ASR::ttype_t* t, tnew->m_kind, tnew->m_len, tnew->m_len_expr)); break; } + case ASR::ttypeType::Byte: { + ASR::Byte_t* tnew = ASR::down_cast(t); + t_ = ASRUtils::TYPE(ASR::make_Byte_t(al, t->base.loc, + tnew->m_kind, tnew->m_len, tnew->m_len_expr)); + break; + } case ASR::ttypeType::StructType: { ASR::StructType_t* tnew = ASR::down_cast(t); t_ = ASRUtils::TYPE(ASR::make_StructType_t(al, t->base.loc, @@ -2696,6 +2721,11 @@ static inline ASR::ttype_t* duplicate_type_without_dims(Allocator& al, const ASR return ASRUtils::TYPE(ASR::make_Character_t(al, loc, tnew->m_kind, tnew->m_len, tnew->m_len_expr)); } + case ASR::ttypeType::Byte: { + ASR::Byte_t* tnew = ASR::down_cast(t); + return ASRUtils::TYPE(ASR::make_Byte_t(al, loc, + tnew->m_kind, tnew->m_len, tnew->m_len_expr)); + } case ASR::ttypeType::StructType: { ASR::StructType_t* tstruct = ASR::down_cast(t); return ASRUtils::TYPE(ASR::make_StructType_t(al, t->base.loc, @@ -3123,6 +3153,11 @@ inline bool types_equal(ASR::ttype_t *a, ASR::ttype_t *b, ASR::Character_t *b2 = ASR::down_cast(b); return (a2->m_kind == b2->m_kind); } + case (ASR::ttypeType::Byte) : { + ASR::Byte_t *a2 = ASR::down_cast(a); + ASR::Byte_t *b2 = ASR::down_cast(b); + return (a2->m_kind == b2->m_kind); + } case (ASR::ttypeType::List) : { ASR::List_t *a2 = ASR::down_cast(a); ASR::List_t *b2 = ASR::down_cast(b); @@ -3306,6 +3341,11 @@ inline bool types_equal_with_substitution(ASR::ttype_t *a, ASR::ttype_t *b, ASR::Character_t *b2 = ASR::down_cast(b); return (a2->m_kind == b2->m_kind); } + case (ASR::ttypeType::Byte) : { + ASR::Byte_t *a2 = ASR::down_cast(a); + ASR::Byte_t *b2 = ASR::down_cast(b); + return (a2->m_kind == b2->m_kind); + } case (ASR::ttypeType::List) : { ASR::List_t *a2 = ASR::down_cast(a); ASR::List_t *b2 = ASR::down_cast(b); diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index ec8a8b0205..6129f3fedf 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -82,6 +82,23 @@ void string_init(llvm::LLVMContext &context, llvm::Module &module, builder.CreateCall(fn, args); } +void bytes_init(llvm::LLVMContext &context, llvm::Module &module, + llvm::IRBuilder<> &builder, llvm::Value* arg_size, llvm::Value* arg_bytes) { + std::string func_name = "_lfortran_bytes_init"; + llvm::Function *fn = module.getFunction(func_name); + if (!fn) { + llvm::FunctionType *function_type = llvm::FunctionType::get( + llvm::Type::getVoidTy(context), { + llvm::Type::getInt32Ty(context), + llvm::Type::getInt8PtrTy(context) + }, true); + fn = llvm::Function::Create(function_type, + llvm::Function::ExternalLinkage, func_name, module); + } + std::vector args = {arg_size, arg_bytes}; + builder.CreateCall(fn, args); +} + class ASRToLLVMVisitor : public ASR::BaseVisitor { private: @@ -143,7 +160,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor bool prototype_only; llvm::StructType *complex_type_4, *complex_type_8; llvm::StructType *complex_type_4_ptr, *complex_type_8_ptr; - llvm::PointerType *character_type; + llvm::PointerType *character_type, *byte_type; llvm::PointerType *list_type; std::vector struct_type_stack; @@ -910,6 +927,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor complex_type_4_ptr = llvm_utils->complex_type_4_ptr; complex_type_8_ptr = llvm_utils->complex_type_8_ptr; character_type = llvm_utils->character_type; + byte_type = llvm_utils->character_type; list_type = llvm::Type::getInt8PtrTy(context); llvm::Type* bound_arg = static_cast(arr_descr->get_dimension_descriptor_type(true)); @@ -948,7 +966,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor prototype_only = false; // TODO: handle dependencies across modules and main program - +; // Then do all the modules in the right order std::vector build_order = determine_module_dependencies(x); @@ -2879,6 +2897,25 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } } llvm_symtab[h] = ptr; + } else if (x.m_type->type == ASR::ttypeType::Byte) { + llvm::Constant *ptr = module->getOrInsertGlobal(x.m_name, + character_type); + if (!external) { + if (init_value) { + module->getNamedGlobal(x.m_name)->setInitializer( + init_value); + } else { + module->getNamedGlobal(x.m_name)->setInitializer( + llvm::Constant::getNullValue(character_type) + ); + ASR::Byte_t *t = down_cast(x.m_type); + if( t->m_len >= 0 ) { + strings_to_be_allocated.insert(std::pair(ptr, llvm::ConstantInt::get( + context, llvm::APInt(32, t->m_len+1)))); + } + } + } + llvm_symtab[h] = ptr; } else if( x.m_type->type == ASR::ttypeType::CPtr ) { llvm::Type* void_ptr = llvm::Type::getVoidTy(context)->getPointerTo(); llvm::Constant *ptr = module->getOrInsertGlobal(x.m_name, @@ -3889,6 +3926,36 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } else { throw CodeGenError("Unsupported len value in ASR " + std::to_string(strlen)); } + } else if (is_a(*v->m_type) && !is_array_type && !is_list) { + ASR::Byte_t *t = down_cast(v->m_type); + target_var = ptr; + int byte_len = t->m_len; + if (byte_len >= 0 || byte_len == -3) { + llvm::Value *arg_size; + if (byte_len == -3) { + LCOMPILERS_ASSERT(t->m_len_expr) + this->visit_expr(*t->m_len_expr); + arg_size = builder->CreateAdd(builder->CreateSExtOrTrunc(tmp, + llvm::Type::getInt32Ty(context)), + llvm::ConstantInt::get(context, llvm::APInt(32, 1)) ); + } else { + // Compile time length + arg_size = llvm::ConstantInt::get(context, + llvm::APInt(32, byte_len+1)); + } + llvm::Value *init_value = LLVM::lfortran_malloc(context, *module, *builder, arg_size); + string_init(context, *module, *builder, arg_size, init_value); + builder->CreateStore(init_value, target_var); + if (v->m_intent == intent_local) { + strings_to_be_deallocated.push_back(al, CreateLoad(target_var)); + } + } else if (byte_len == -2) { + // Allocatable string. Initialize to `nullptr` (unallocated) + llvm::Value *init_value = llvm::Constant::getNullValue(type); + builder->CreateStore(init_value, target_var); + } else { + throw CodeGenError("Unsupported bytes len value in ASR " + std::to_string(byte_len)); + } } else if (is_list) { ASR::List_t* asr_list = ASR::down_cast(v->m_type); std::string type_code = ASRUtils::get_type_code(asr_list->m_type); @@ -7072,6 +7139,10 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor tmp = builder->CreateGlobalStringPtr(x.m_s); } + void visit_BytesConstant(const ASR::BytesConstant_t &x) { + tmp = builder->CreateGlobalStringPtr(x.m_s); + } + inline void fetch_ptr(ASR::Variable_t* x) { uint32_t x_h = get_hash((ASR::asr_t*)x); LCOMPILERS_ASSERT(llvm_symtab.find(x_h) != llvm_symtab.end()); @@ -7128,6 +7199,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor case ASR::ttypeType::Complex: case ASR::ttypeType::StructType: case ASR::ttypeType::Character: + case ASR::ttypeType::Byte: case ASR::ttypeType::Logical: case ASR::ttypeType::Class: { if( t2->type == ASR::ttypeType::StructType ) { @@ -8848,6 +8920,21 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor target_type = character_type; break; } + case (ASR::ttypeType::Byte) : { + ASR::Variable_t *orig_arg = nullptr; + if( func_subrout->type == ASR::symbolType::Function ) { + ASR::Function_t* func = down_cast(func_subrout); + orig_arg = ASRUtils::EXPR2VAR(func->m_args[i]); + } else { + throw CodeGenError("ICE: expected func_subrout->type == ASR::symbolType::Function."); + } + if (orig_arg->m_abi == ASR::abiType::BindC) { + character_bindc = true; + } + + target_type = character_type; + break; + } case (ASR::ttypeType::Logical) : target_type = llvm::Type::getInt1Ty(context); break; diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index dde91aa6d9..f2813dc4e2 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -192,6 +192,10 @@ namespace LCompilers { llvm_mem_type = character_type; break; } + case ASR::ttypeType::Byte: { + llvm_mem_type = character_type; + break; + } case ASR::ttypeType::CPtr: { llvm_mem_type = llvm::Type::getVoidTy(context)->getPointerTo(); break; @@ -509,6 +513,10 @@ namespace LCompilers { el_type = character_type; break; } + case ASR::ttypeType::Byte: { + el_type = character_type; + break; + } default: LCOMPILERS_ASSERT(false); break; @@ -753,6 +761,16 @@ namespace LCompilers { } break; } + case (ASR::ttypeType::Byte) : { + ASR::Byte_t* v_type = ASR::down_cast(asr_type); + a_kind = v_type->m_kind; + if (arg_m_abi == ASR::abiType::BindC) { + type = character_type; + } else { + type = character_type->getPointerTo(); + } + break; + } case (ASR::ttypeType::Logical) : { ASR::Logical_t* v_type = ASR::down_cast(asr_type); a_kind = v_type->m_kind; @@ -1005,6 +1023,9 @@ namespace LCompilers { case (ASR::ttypeType::Character) : return_type = character_type; break; + case (ASR::ttypeType::Byte) : + return_type = character_type; + break; case (ASR::ttypeType::Logical) : return_type = llvm::Type::getInt1Ty(context); break; @@ -1203,6 +1224,9 @@ namespace LCompilers { case (ASR::ttypeType::Character) : return_type = character_type; break; + case (ASR::ttypeType::Byte) : + return_type = character_type; + break; case (ASR::ttypeType::Logical) : return_type = llvm::Type::getInt1Ty(context); break; @@ -1422,6 +1446,12 @@ namespace LCompilers { llvm_type = character_type; break; } + case (ASR::ttypeType::Byte) : { + ASR::Byte_t* v_type = ASR::down_cast(asr_type); + a_kind = v_type->m_kind; + llvm_type = character_type; + break; + } case (ASR::ttypeType::Logical) : { ASR::Logical_t* v_type = ASR::down_cast(asr_type); a_kind = v_type->m_kind; @@ -1708,6 +1738,51 @@ namespace LCompilers { llvm::Value* r = LLVM::CreateLoad(*builder, create_ptr_gep(right, i)); return builder->CreateICmpEQ(l, r); } + case ASR::ttypeType::Byte: { + get_builder0() + str_cmp_itr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + llvm::Value* null_char = llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), + llvm::APInt(8, '\0')); + llvm::Value* idx = str_cmp_itr; + LLVM::CreateStore(*builder, + llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)), + idx); + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + + // head + start_new_block(loophead); + { + llvm::Value* i = LLVM::CreateLoad(*builder, idx); + llvm::Value* l = LLVM::CreateLoad(*builder, create_ptr_gep(left, i)); + llvm::Value* r = LLVM::CreateLoad(*builder, create_ptr_gep(right, i)); + llvm::Value *cond = builder->CreateAnd( + builder->CreateICmpNE(l, null_char), + builder->CreateICmpNE(r, null_char) + ); + cond = builder->CreateAnd(cond, builder->CreateICmpEQ(l, r)); + builder->CreateCondBr(cond, loopbody, loopend); + } + + // body + start_new_block(loopbody); + { + llvm::Value* i = LLVM::CreateLoad(*builder, idx); + i = builder->CreateAdd(i, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 1))); + LLVM::CreateStore(*builder, i, idx); + } + + builder->CreateBr(loophead); + + // end + start_new_block(loopend); + llvm::Value* i = LLVM::CreateLoad(*builder, idx); + llvm::Value* l = LLVM::CreateLoad(*builder, create_ptr_gep(left, i)); + llvm::Value* r = LLVM::CreateLoad(*builder, create_ptr_gep(right, i)); + return builder->CreateICmpEQ(l, r); + } case ASR::ttypeType::Tuple: { ASR::Tuple_t* tuple_type = ASR::down_cast(asr_type); return tuple_api->check_tuple_equality(left, right, tuple_type, context, @@ -1857,6 +1932,72 @@ namespace LCompilers { llvm::Value* r = LLVM::CreateLoad(*builder, create_ptr_gep(right, i)); return builder->CreateICmpULT(l, r); } + case ASR::ttypeType::Byte: { + get_builder0() + str_cmp_itr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + llvm::Value* null_char = llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), + llvm::APInt(8, '\0')); + llvm::Value* idx = str_cmp_itr; + LLVM::CreateStore(*builder, + llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)), + idx); + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + + // head + start_new_block(loophead); + { + llvm::Value* i = LLVM::CreateLoad(*builder, idx); + llvm::Value* l = LLVM::CreateLoad(*builder, create_ptr_gep(left, i)); + llvm::Value* r = LLVM::CreateLoad(*builder, create_ptr_gep(right, i)); + llvm::Value *cond = builder->CreateAnd( + builder->CreateICmpNE(l, null_char), + builder->CreateICmpNE(r, null_char) + ); + switch( overload_id ) { + case 0: { + pred = llvm::CmpInst::Predicate::ICMP_ULT; + break; + } + case 1: { + pred = llvm::CmpInst::Predicate::ICMP_ULE; + break; + } + case 2: { + pred = llvm::CmpInst::Predicate::ICMP_UGT; + break; + } + case 3: { + pred = llvm::CmpInst::Predicate::ICMP_UGE; + break; + } + default: { + throw CodeGenError("Un-recognized overload-id: " + std::to_string(overload_id)); + } + } + cond = builder->CreateAnd(cond, builder->CreateICmp(pred, l, r)); + builder->CreateCondBr(cond, loopbody, loopend); + } + + // body + start_new_block(loopbody); + { + llvm::Value* i = LLVM::CreateLoad(*builder, idx); + i = builder->CreateAdd(i, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 1))); + LLVM::CreateStore(*builder, i, idx); + } + + builder->CreateBr(loophead); + + // end + start_new_block(loopend); + llvm::Value* i = LLVM::CreateLoad(*builder, idx); + llvm::Value* l = LLVM::CreateLoad(*builder, create_ptr_gep(left, i)); + llvm::Value* r = LLVM::CreateLoad(*builder, create_ptr_gep(right, i)); + return builder->CreateICmpULT(l, r); + } case ASR::ttypeType::Tuple: { ASR::Tuple_t* tuple_type = ASR::down_cast(asr_type); return tuple_api->check_tuple_inequality(left, right, tuple_type, context, @@ -1917,6 +2058,7 @@ namespace LCompilers { break ; }; case ASR::ttypeType::Character: + case ASR::ttypeType::Byte: case ASR::ttypeType::FunctionType: case ASR::ttypeType::CPtr: { LLVM::CreateStore(*builder, src, dest); @@ -2002,6 +2144,7 @@ namespace LCompilers { case ASR::ttypeType::Logical: case ASR::ttypeType::Complex: case ASR::ttypeType::Character: + case ASR::ttypeType::Byte: case ASR::ttypeType::FunctionType: case ASR::ttypeType::CPtr: case ASR::ttypeType::Allocatable: { @@ -3738,6 +3881,70 @@ namespace LCompilers { hash = builder->CreateTrunc(hash, llvm::Type::getInt32Ty(context)); return builder->CreateSRem(hash, capacity); } + case ASR::ttypeType::Byte: { + // Polynomial rolling hash function for strings + llvm::Value* null_char = llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), + llvm::APInt(8, '\0')); + llvm::Value* p = llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 31)); + llvm::Value* m = llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 100000009)); + get_builder0() + hash_value = builder0.CreateAlloca(llvm::Type::getInt64Ty(context), nullptr, "hash_value"); + hash_iter = builder0.CreateAlloca(llvm::Type::getInt64Ty(context), nullptr, "hash_iter"); + polynomial_powers = builder0.CreateAlloca(llvm::Type::getInt64Ty(context), nullptr, "p_pow"); + LLVM::CreateStore(*builder, + llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 0)), + hash_value); + LLVM::CreateStore(*builder, + llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 1)), + polynomial_powers); + LLVM::CreateStore(*builder, + llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 0)), + hash_iter); + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + + // head + llvm_utils->start_new_block(loophead); + { + llvm::Value* i = LLVM::CreateLoad(*builder, hash_iter); + llvm::Value* c = LLVM::CreateLoad(*builder, llvm_utils->create_ptr_gep(key, i)); + llvm::Value *cond = builder->CreateICmpNE(c, null_char); + builder->CreateCondBr(cond, loopbody, loopend); + } + + // body + llvm_utils->start_new_block(loopbody); + { + // for c in key: + // hash_value = (hash_value + (ord(c) + 1) * p_pow) % m + // p_pow = (p_pow * p) % m + llvm::Value* i = LLVM::CreateLoad(*builder, hash_iter); + llvm::Value* c = LLVM::CreateLoad(*builder, llvm_utils->create_ptr_gep(key, i)); + llvm::Value* p_pow = LLVM::CreateLoad(*builder, polynomial_powers); + llvm::Value* hash = LLVM::CreateLoad(*builder, hash_value); + c = builder->CreateZExt(c, llvm::Type::getInt64Ty(context)); + c = builder->CreateAdd(c, llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 1))); + c = builder->CreateMul(c, p_pow); + c = builder->CreateSRem(c, m); + hash = builder->CreateAdd(hash, c); + hash = builder->CreateSRem(hash, m); + LLVM::CreateStore(*builder, hash, hash_value); + p_pow = builder->CreateMul(p_pow, p); + p_pow = builder->CreateSRem(p_pow, m); + LLVM::CreateStore(*builder, p_pow, polynomial_powers); + i = builder->CreateAdd(i, llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 1))); + LLVM::CreateStore(*builder, i, hash_iter); + } + + builder->CreateBr(loophead); + + // end + llvm_utils->start_new_block(loopend); + llvm::Value* hash = LLVM::CreateLoad(*builder, hash_value); + hash = builder->CreateTrunc(hash, llvm::Type::getInt32Ty(context)); + return builder->CreateSRem(hash, capacity); + } case ASR::ttypeType::Tuple: { llvm::Value* tuple_hash = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)); ASR::Tuple_t* asr_tuple = ASR::down_cast(key_asr_type); @@ -5935,6 +6142,70 @@ namespace LCompilers { hash = builder->CreateTrunc(hash, llvm::Type::getInt32Ty(context)); return builder->CreateSRem(hash, capacity); } + case ASR::ttypeType::Byte: { + // Polynomial rolling hash function for bytes + llvm::Value* null_char = llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), + llvm::APInt(8, '\0')); + llvm::Value* p = llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 31)); + llvm::Value* m = llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 100000009)); + get_builder0() + hash_value = builder0.CreateAlloca(llvm::Type::getInt64Ty(context), nullptr, "hash_value"); + hash_iter = builder0.CreateAlloca(llvm::Type::getInt64Ty(context), nullptr, "hash_iter"); + polynomial_powers = builder0.CreateAlloca(llvm::Type::getInt64Ty(context), nullptr, "p_pow"); + LLVM::CreateStore(*builder, + llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 0)), + hash_value); + LLVM::CreateStore(*builder, + llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 1)), + polynomial_powers); + LLVM::CreateStore(*builder, + llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 0)), + hash_iter); + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + + // head + llvm_utils->start_new_block(loophead); + { + llvm::Value* i = LLVM::CreateLoad(*builder, hash_iter); + llvm::Value* c = LLVM::CreateLoad(*builder, llvm_utils->create_ptr_gep(el, i)); + llvm::Value *cond = builder->CreateICmpNE(c, null_char); + builder->CreateCondBr(cond, loopbody, loopend); + } + + // body + llvm_utils->start_new_block(loopbody); + { + // for c in el: + // hash_value = (hash_value + (ord(c) + 1) * p_pow) % m + // p_pow = (p_pow * p) % m + llvm::Value* i = LLVM::CreateLoad(*builder, hash_iter); + llvm::Value* c = LLVM::CreateLoad(*builder, llvm_utils->create_ptr_gep(el, i)); + llvm::Value* p_pow = LLVM::CreateLoad(*builder, polynomial_powers); + llvm::Value* hash = LLVM::CreateLoad(*builder, hash_value); + c = builder->CreateZExt(c, llvm::Type::getInt64Ty(context)); + c = builder->CreateAdd(c, llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 1))); + c = builder->CreateMul(c, p_pow); + c = builder->CreateSRem(c, m); + hash = builder->CreateAdd(hash, c); + hash = builder->CreateSRem(hash, m); + LLVM::CreateStore(*builder, hash, hash_value); + p_pow = builder->CreateMul(p_pow, p); + p_pow = builder->CreateSRem(p_pow, m); + LLVM::CreateStore(*builder, p_pow, polynomial_powers); + i = builder->CreateAdd(i, llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 1))); + LLVM::CreateStore(*builder, i, hash_iter); + } + + builder->CreateBr(loophead); + + // end + llvm_utils->start_new_block(loopend); + llvm::Value* hash = LLVM::CreateLoad(*builder, hash_value); + hash = builder->CreateTrunc(hash, llvm::Type::getInt32Ty(context)); + return builder->CreateSRem(hash, capacity); + } case ASR::ttypeType::Tuple: { llvm::Value* tuple_hash = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)); ASR::Tuple_t* asr_tuple = ASR::down_cast(el_asr_type); diff --git a/src/libasr/pass/global_stmts.cpp b/src/libasr/pass/global_stmts.cpp index 7fc1e8e6c4..9517ab5e8d 100644 --- a/src/libasr/pass/global_stmts.cpp +++ b/src/libasr/pass/global_stmts.cpp @@ -51,6 +51,7 @@ void pass_wrap_global_stmts(Allocator &al, (ASRUtils::expr_type(value)->type == ASR::ttypeType::Real) || (ASRUtils::expr_type(value)->type == ASR::ttypeType::Complex) || (ASRUtils::expr_type(value)->type == ASR::ttypeType::Character) || + (ASRUtils::expr_type(value)->type == ASR::ttypeType::Byte) || (ASRUtils::expr_type(value)->type == ASR::ttypeType::List) || (ASRUtils::expr_type(value)->type == ASR::ttypeType::Tuple) || (ASRUtils::expr_type(value)->type == ASR::ttypeType::StructType)) { diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index a44694d95d..eeb5275b9b 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -886,6 +886,9 @@ class CommonVisitor : public AST::BaseVisitor { } else if (var_annotation == "c64") { type = ASRUtils::TYPE(ASR::make_Complex_t(al, loc, 8)); type = ASRUtils::make_Array_t_util(al, loc, type, dims.p, dims.size(), abi, is_argument); + } else if (var_annotation == "bytes") { + type = ASRUtils::TYPE(ASR::make_Byte_t(al, loc, 1, -2, nullptr)); + type = ASRUtils::make_Array_t_util(al, loc, type, dims.p, dims.size(), abi, is_argument); } else if (var_annotation == "str") { type = ASRUtils::TYPE(ASR::make_Character_t(al, loc, 1, -2, nullptr)); type = ASRUtils::make_Array_t_util(al, loc, type, dims.p, dims.size(), abi, is_argument); @@ -2899,6 +2902,7 @@ class CommonVisitor : public AST::BaseVisitor { } else { type = ast_expr_to_asr_type(x.m_annotation->base.loc, *x.m_annotation, is_allocatable, is_const, true, abi); } + if (ASR::is_a(*type)) { ASR::FunctionType_t* fn_type = ASR::down_cast(type); handle_lambda_function_declaration(var_name, fn_type, x.m_value, x.base.base.loc); @@ -2956,6 +2960,7 @@ class CommonVisitor : public AST::BaseVisitor { } else { cast_helper(type, init_expr, init_expr->base.loc); } + if (!inside_struct || is_const) { process_variable_init_val(current_scope->get_symbol(var_name), x.base.base.loc, init_expr); @@ -3567,6 +3572,13 @@ class CommonVisitor : public AST::BaseVisitor { 1, s_size, nullptr)); tmp = ASR::make_StringConstant_t(al, x.base.base.loc, s, type); } + void visit_ConstantBytes(const AST::ConstantBytes_t &x) { + char *s = x.m_value; + size_t s_size = std::string(s).size(); + ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_Byte_t(al, x.base.base.loc, + 1, s_size, nullptr)); + tmp = ASR::make_BytesConstant_t(al, x.base.base.loc, s, type); + } void visit_ConstantBool(const AST::ConstantBool_t &x) { bool b = x.m_value; @@ -9132,7 +9144,6 @@ Result python_ast_to_asr(Allocator &al, LocationManager }; #endif } - return tu; }