From 5ad9b066f26510f25353fbcef6b76b4c103aee07 Mon Sep 17 00:00:00 2001 From: Henry Zongaro Date: Wed, 11 Sep 2024 07:35:39 -0700 Subject: [PATCH 1/2] Guard against overflow of String Builder Transformer estimate String Builder Transformer uses the result of getStringUTF8Length to estimate the StringBuilder buffer size needed to accommodate appending a constant String to a StringBuilder. That could overestimate the space required. This has been changed to use getStringLength instead, to use the actual lengths of constant String objects. A test has also been added to detect integer overflow of the capacity estimate, aborting the transformation, as StringBuilder. will throw a NegativeArraySizeException if the specified capacity is negative. Signed-off-by: Henry Zongaro --- .../optimizer/StringBuilderTransformer.cpp | 29 ++++++++++++++++--- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/runtime/compiler/optimizer/StringBuilderTransformer.cpp b/runtime/compiler/optimizer/StringBuilderTransformer.cpp index 0362004ce2f..0c5c8a68ebc 100644 --- a/runtime/compiler/optimizer/StringBuilderTransformer.cpp +++ b/runtime/compiler/optimizer/StringBuilderTransformer.cpp @@ -142,6 +142,15 @@ int32_t TR_StringBuilderTransformer::performOnBlock(TR::Block* block) { int32_t capacity = computeHeuristicStringBuilderInitCapacity(appendArguments); + // Guard against the possibility that the computed capacity has overflowed, + // as StringBuilder.(I) will throw a NegativeArraySizeException if the + // capacity argument is negative. It is extremely unlikely that the capacity + // calculation will overflow, but possible. + if (capacity < 0) + { + return 1; + } + if (performTransformation(comp(), "%sTransforming java/lang/StringBuilder.()V call at node [0x%p] to java/lang/StringBuilder.(I)V with capacity = %d\n", OPT_DETAILS, initNode, capacity)) { static const bool collectAppendStatistics = feGetEnv("TR_StringBuilderTransformerCollectAppendStatistics") != NULL; @@ -465,14 +474,23 @@ TR::Node* TR_StringBuilderTransformer::findStringBuilderChainedAppendArguments(T * The heuristic used in the non-constant case is hard coded per object type and was determined using the * statistics collection mechanisms implemented by this optimization. See TR_StringBuilderTransformer Environment * Variables section for more details. + * + * \returns A non-negative computed capacity or a negative value if the capacity estimate resulted in an integer overflow */ int32_t TR_StringBuilderTransformer::computeHeuristicStringBuilderInitCapacity(List >& appendArguments) { - int32_t capacity = 0; + uint32_t capacity = 0; ListIterator > iter(&appendArguments); - for (TR_Pair* pair = iter.getFirst(); pair != NULL; pair = iter.getNext()) + // Iterate over pairs of recognized methods and arguments that can be used + // to estimate the size of the StringBuilder buffer that will be needed. + // If the estimated capacity ever exceeds the maximum int32_t value, the + // calculation has overflowed, so halt early. + // + for (TR_Pair* pair = iter.getFirst(); + (pair != NULL) && (capacity <= std::numeric_limits::max()); + pair = iter.getNext()) { TR::Node* argument = pair->getKey(); @@ -595,7 +613,7 @@ int32_t TR_StringBuilderTransformer::computeHeuristicStringBuilderInitCapacity(L { uintptr_t stringObjectLocation = (uintptr_t)symbol->castToStaticSymbol()->getStaticAddress(); uintptr_t stringObject = comp()->fej9()->getStaticReferenceFieldAtAddress(stringObjectLocation); - capacity += comp()->fe()->getStringUTF8Length(stringObject); + capacity += comp()->fej9()->getStringLength(stringObject); break; } @@ -625,5 +643,8 @@ int32_t TR_StringBuilderTransformer::computeHeuristicStringBuilderInitCapacity(L } } - return capacity; + // If the loop has halted early because the value of capacity is greater than the + // int32_t value, casting the capacity to int32_t will yield a negative result, + // signalling that the capacity calculation has failed. + return (int32_t) capacity; } From 998b8098027a255d60861fcb120207c78339c6bb Mon Sep 17 00:00:00 2001 From: Henry Zongaro Date: Wed, 11 Sep 2024 07:37:46 -0700 Subject: [PATCH 2/2] Adjust expected return type for calls to getStringUTF8Length The return type of the JVM's getStringUTF8Length method has changed from IDATA to UDATA. This change adjusts the JIT compiler's uses of that method. In particular, the return type of TR_FrontEnd::getStringUTF8Length and its overriding methods changes from intptr_t to int32_t. Similarly, the bufferSize argument of TR_FrontEnd::getStringUTF8 becomes uintptr_t where it was int32_t. The motivation was that the length of a UTF-8 encoded String could be greater than 2^32 bytes, so the length could overflow on a 32-bit platform. All uses of getStringUTF8Length in the JIT involve signatures, descriptors, method names and class names, which will never be large enough to exceed the range of the int32_t type. Just to be cautious, however, this change includes an assertion test that the computed length is in range for the int32_t type, allowing for a maximum length of 2^31-2. That ensures that any code that then uses that length to allocate a buffer to contain the encoded String with a NUL terminator will not overflow a 32-bit signed integer representation for the length plus the NUL byte. This change also introduces a method, getStringUTF8UnabbreviatedLength, that returns a length value of type uint64_t if the JIT compiler ever needs to determine the UTF-8 encoded length of an arbitrary String. The method is currently unused. Finally, this change removes the VM_getStringUTF8Length JITServer message type, which is never used. Signed-off-by: Henry Zongaro --- .../control/JITClientCompilationThread.cpp | 13 ++------- runtime/compiler/env/VMJ9.cpp | 23 +++++++++++++--- runtime/compiler/env/VMJ9.h | 27 ++++++++++++++++--- runtime/compiler/env/VMJ9Server.cpp | 12 ++++++--- runtime/compiler/env/VMJ9Server.hpp | 3 ++- runtime/compiler/env/j9method.cpp | 4 +-- runtime/compiler/net/CommunicationStream.hpp | 2 +- runtime/compiler/net/MessageTypes.cpp | 1 - runtime/compiler/net/MessageTypes.hpp | 1 - 9 files changed, 58 insertions(+), 28 deletions(-) diff --git a/runtime/compiler/control/JITClientCompilationThread.cpp b/runtime/compiler/control/JITClientCompilationThread.cpp index 8abe816d966..8922adcb86b 100644 --- a/runtime/compiler/control/JITClientCompilationThread.cpp +++ b/runtime/compiler/control/JITClientCompilationThread.cpp @@ -630,15 +630,6 @@ handleServerMessage(JITServer::ClientStream *client, TR_J9VM *fe, JITServer::Mes client->write(response, fe->stackWalkerMaySkipFrames(method, clazz)); } break; - case MessageType::VM_getStringUTF8Length: - { - uintptr_t string = std::get<0>(client->getRecvData()); - { - TR::VMAccessCriticalSection getStringUTF8Length(fe); - client->write(response, fe->getStringUTF8Length(string)); - } - } - break; case MessageType::VM_classInitIsFinished: { TR_OpaqueClassBlock *clazz = std::get<0>(client->getRecvData()); @@ -2351,7 +2342,7 @@ handleServerMessage(JITServer::ClientStream *client, TR_J9VM *fe, JITServer::Mes "next", "Ljava/lang/invoke/MethodHandle;"), "type", "Ljava/lang/invoke/MethodType;"), "methodDescriptor", "Ljava/lang/String;"); - size_t methodDescriptorLength = fe->getStringUTF8Length(methodDescriptorRef); + intptr_t methodDescriptorLength = fe->getStringUTF8Length(methodDescriptorRef); char *methodDescriptor = (char*)alloca(methodDescriptorLength+1); fe->getStringUTF8(methodDescriptorRef, methodDescriptor, methodDescriptorLength+1); client->write(response, std::string(methodDescriptor, methodDescriptorLength)); @@ -2504,7 +2495,7 @@ handleServerMessage(JITServer::ClientStream *client, TR_J9VM *fe, JITServer::Mes int32_t numArgsPassToFinallyTarget = (int32_t)fe->getArrayLengthInElements(arguments); uintptr_t methodDescriptorRef = fe->getReferenceField(finallyType, "methodDescriptor", "Ljava/lang/String;"); - int methodDescriptorLength = fe->getStringUTF8Length(methodDescriptorRef); + intptr_t methodDescriptorLength = fe->getStringUTF8Length(methodDescriptorRef); char *methodDescriptor = (char*)alloca(methodDescriptorLength+1); fe->getStringUTF8(methodDescriptorRef, methodDescriptor, methodDescriptorLength+1); client->write(response, numArgsPassToFinallyTarget, std::string(methodDescriptor, methodDescriptorLength)); diff --git a/runtime/compiler/env/VMJ9.cpp b/runtime/compiler/env/VMJ9.cpp index f34ef659724..e6638c1559e 100644 --- a/runtime/compiler/env/VMJ9.cpp +++ b/runtime/compiler/env/VMJ9.cpp @@ -5537,18 +5537,33 @@ TR_J9VMBase::getStringCharacter(uintptr_t objectPointer, int32_t index) } } -intptr_t +int32_t TR_J9VMBase::getStringUTF8Length(uintptr_t objectPointer) { TR_ASSERT(haveAccess(), "Must have VM access to call getStringUTF8Length"); TR_ASSERT(objectPointer, "assertion failure"); - return vmThread()->javaVM->internalVMFunctions->getStringUTF8Length(vmThread(), (j9object_t)objectPointer); + uint64_t actualLength = vmThread()->javaVM->internalVMFunctions->getStringUTF8LengthTruncated(vmThread(), (j9object_t)objectPointer, INT64_MAX); + + // Fail if length+1 cannot be represented as an int32_t value. The extra byte accounts for + // any NUL terminator that might be needed in copying the UTF-8 encoded string into a buffer + TR_ASSERT_FATAL(actualLength+1 <= std::numeric_limits::max(), "UTF8-encoded String length of " UINT64_PRINTF_FORMAT " must be in the range permitted for type int32_t, also allowing for a NUL terminator.\n", actualLength); + + return (int32_t) actualLength; + } + + +uint64_t +TR_J9VMBase::getStringUTF8UnabbreviatedLength(uintptr_t objectPointer) + { + TR_ASSERT(haveAccess(), "Must have VM access to call getStringUTF8Length"); + TR_ASSERT(objectPointer, "assertion failure"); + return vmThread()->javaVM->internalVMFunctions->getStringUTF8LengthTruncated(vmThread(), (j9object_t)objectPointer, INT64_MAX); } char * -TR_J9VMBase::getStringUTF8(uintptr_t objectPointer, char *buffer, intptr_t bufferSize) +TR_J9VMBase::getStringUTF8(uintptr_t objectPointer, char *buffer, uintptr_t bufferSize) { - TR_ASSERT(haveAccess(), "Must have VM access to call getStringAscii"); + TR_ASSERT(haveAccess(), "Must have VM access to call getStringUTF8"); vmThread()->javaVM->internalVMFunctions->copyStringToUTF8Helper(vmThread(), (j9object_t)objectPointer, J9_STR_NULL_TERMINATE_RESULT, 0, J9VMJAVALANGSTRING_LENGTH(vmThread(), objectPointer), (U_8*)buffer, (UDATA)bufferSize); diff --git a/runtime/compiler/env/VMJ9.h b/runtime/compiler/env/VMJ9.h index 6d8e2847e90..e944b37a0c8 100644 --- a/runtime/compiler/env/VMJ9.h +++ b/runtime/compiler/env/VMJ9.h @@ -1117,7 +1117,7 @@ class TR_J9VMBase : public TR_FrontEnd */ virtual bool isChangesCurrentThread(TR_ResolvedMethod *method); - /* + /** * \brief * tell whether it's possible to dereference a field given the field symbol at compile time * @@ -1138,8 +1138,29 @@ class TR_J9VMBase : public TR_FrontEnd virtual bool isJavaLangObject(TR_OpaqueClassBlock *clazz); virtual int32_t getStringLength(uintptr_t objectPointer); virtual uint16_t getStringCharacter(uintptr_t objectPointer, int32_t index); - virtual intptr_t getStringUTF8Length(uintptr_t objectPointer); - virtual char *getStringUTF8 (uintptr_t objectPointer, char *buffer, intptr_t bufferSize); + + /** + * \brief Returns the number of UTF-8 encoded bytes needed to represent a Java String object. + * The number of bytes needed to UTF-8 encode the String is representable as + * a \c uint64_t, in general, but this method returns a length of type \c int32_t. + * If the length might exceed the range of \c int32_t, use + * \ref getStringUTF8UnabbreviatedLength instead. + * + * \param[in] objectPointer A pointer to a Java String object + * + * \return The number of UTF-8 encoded bytes needed to represent the String + */ + virtual int32_t getStringUTF8Length(uintptr_t objectPointer); + + /** + * \brief Returns the number of UTF-8 encoded bytes needed to represent a Java String object. + * + * \param[in] objectPointer A pointer to a Java String object + * + * \return The number of UTF-8 encoded bytes needed to represent the String + */ + virtual uint64_t getStringUTF8UnabbreviatedLength(uintptr_t objectPointer); + virtual char *getStringUTF8(uintptr_t objectPointer, char *buffer, uintptr_t bufferSize); virtual uint32_t getVarHandleHandleTableOffset(TR::Compilation *); diff --git a/runtime/compiler/env/VMJ9Server.cpp b/runtime/compiler/env/VMJ9Server.cpp index 382d7c6176a..b4f5962d02b 100644 --- a/runtime/compiler/env/VMJ9Server.cpp +++ b/runtime/compiler/env/VMJ9Server.cpp @@ -1073,12 +1073,16 @@ TR_J9ServerVM::getHostClass(TR_OpaqueClassBlock *clazz) return hostClass; } -intptr_t +int32_t TR_J9ServerVM::getStringUTF8Length(uintptr_t objectPointer) { - JITServer::ServerStream *stream = _compInfoPT->getMethodBeingCompiled()->_stream; - stream->write(JITServer::MessageType::VM_getStringUTF8Length, objectPointer); - return std::get<0>(stream->read()); + TR_ASSERT_FATAL(false, "getStringUTF8Length(uintptr_t) should not be called by JITServer"); + } + +uint64_t +TR_J9ServerVM::getStringUTF8UnabbreviatedLength(uintptr_t objectPointer) + { + TR_ASSERT_FATAL(false, "getStringUTF8UnabbreviatedLength(uintptr_t) should not be called by JITServer"); } bool diff --git a/runtime/compiler/env/VMJ9Server.hpp b/runtime/compiler/env/VMJ9Server.hpp index bdae274a4e6..b2027f51f45 100644 --- a/runtime/compiler/env/VMJ9Server.hpp +++ b/runtime/compiler/env/VMJ9Server.hpp @@ -121,7 +121,8 @@ class TR_J9ServerVM: public TR_J9VM virtual bool hasFinalFieldsInClass(TR_OpaqueClassBlock *clazz) override; virtual const char *sampleSignature(TR_OpaqueMethodBlock * aMethod, char *buf, int32_t bufLen, TR_Memory *memory) override; virtual TR_OpaqueClassBlock * getHostClass(TR_OpaqueClassBlock *clazzOffset) override; - virtual intptr_t getStringUTF8Length(uintptr_t objectPointer) override; + virtual int32_t getStringUTF8Length(uintptr_t objectPointer) override; + virtual uint64_t getStringUTF8UnabbreviatedLength(uintptr_t objectPointer) override; virtual bool classInitIsFinished(TR_OpaqueClassBlock *) override; virtual int32_t getNewArrayTypeFromClass(TR_OpaqueClassBlock *clazz) override; virtual TR_OpaqueClassBlock *getClassFromNewArrayType(int32_t arrayType) override; diff --git a/runtime/compiler/env/j9method.cpp b/runtime/compiler/env/j9method.cpp index 2de0aad91b8..fae24d5bcfc 100644 --- a/runtime/compiler/env/j9method.cpp +++ b/runtime/compiler/env/j9method.cpp @@ -8413,7 +8413,7 @@ TR_J9ByteCodeIlGenerator::runFEMacro(TR::SymbolReference *symRef) uintptr_t methodHandle; uintptr_t methodDescriptorRef; - intptr_t methodDescriptorLength; + uintptr_t methodDescriptorLength; #if defined(J9VM_OPT_JITSERVER) if (comp()->isOutOfProcessCompilation()) @@ -9262,7 +9262,7 @@ TR_J9ByteCodeIlGenerator::runFEMacro(TR::SymbolReference *symRef) numArgsPassToFinallyTarget = (int32_t)fej9->getArrayLengthInElements(arguments); uintptr_t methodDescriptorRef = fej9->getReferenceField(finallyType, "methodDescriptor", "Ljava/lang/String;"); - int methodDescriptorLength = fej9->getStringUTF8Length(methodDescriptorRef); + intptr_t methodDescriptorLength = fej9->getStringUTF8Length(methodDescriptorRef); methodDescriptor = (char*)alloca(methodDescriptorLength+1); fej9->getStringUTF8(methodDescriptorRef, methodDescriptor, methodDescriptorLength+1); } diff --git a/runtime/compiler/net/CommunicationStream.hpp b/runtime/compiler/net/CommunicationStream.hpp index ebfb9d3c231..b8050d67222 100644 --- a/runtime/compiler/net/CommunicationStream.hpp +++ b/runtime/compiler/net/CommunicationStream.hpp @@ -129,7 +129,7 @@ class CommunicationStream // likely to lose an increment when merging/rebasing/etc. // static const uint8_t MAJOR_NUMBER = 1; - static const uint16_t MINOR_NUMBER = 78; // ID: DGwBSxx9FLiSwWTdQCIn + static const uint16_t MINOR_NUMBER = 79; // ID: Su+UK1Q5oJlgUkWIBA6f static const uint8_t PATCH_NUMBER = 0; static uint32_t CONFIGURATION_FLAGS; diff --git a/runtime/compiler/net/MessageTypes.cpp b/runtime/compiler/net/MessageTypes.cpp index 4715a258704..1003cf35bdc 100644 --- a/runtime/compiler/net/MessageTypes.cpp +++ b/runtime/compiler/net/MessageTypes.cpp @@ -121,7 +121,6 @@ const char *messageNames[] = "VM_getObjectClassFromKnownObjectIndexJLClass", "VM_getObjectClassInfoFromObjectReferenceLocation", "VM_stackWalkerMaySkipFrames", - "VM_getStringUTF8Length", "VM_classInitIsFinished", "VM_getClassFromNewArrayType", "VM_getArrayClassFromComponentClass", diff --git a/runtime/compiler/net/MessageTypes.hpp b/runtime/compiler/net/MessageTypes.hpp index f90e791aa93..dad54bce271 100644 --- a/runtime/compiler/net/MessageTypes.hpp +++ b/runtime/compiler/net/MessageTypes.hpp @@ -130,7 +130,6 @@ enum MessageType : uint16_t VM_getObjectClassFromKnownObjectIndexJLClass, VM_getObjectClassInfoFromObjectReferenceLocation, VM_stackWalkerMaySkipFrames, - VM_getStringUTF8Length, VM_classInitIsFinished, VM_getClassFromNewArrayType, VM_getArrayClassFromComponentClass,