Skip to content

Commit 30b622d

Browse files
authored
Merge pull request #79773 from mikeash/metadatareader-better-descriptor-sizing
[Reflection] Fix undersized reads in MetadataReader::readContextDescriptor.
2 parents 0997a7e + 91e7daa commit 30b622d

File tree

4 files changed

+114
-100
lines changed

4 files changed

+114
-100
lines changed

include/swift/ABI/GenericContext.h

+8
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,14 @@ class TrailingGenericContextObjects<TargetSelf<Runtime>,
711711
getGenericValueDescriptors().data()};
712712
}
713713

714+
static size_t trailingTypeCount() {
715+
return TrailingObjects::trailingTypeCount();
716+
}
717+
718+
size_t sizeWithTrailingTypeCount(size_t n) const {
719+
return TrailingObjects::sizeWithTrailingTypeCount(n);
720+
}
721+
714722
protected:
715723
size_t numTrailingObjects(OverloadToken<GenericContextHeaderType>) const {
716724
return asSelf()->isGeneric() ? 1 : 0;

include/swift/ABI/Metadata.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -3168,7 +3168,7 @@ struct swift_ptrauth_struct_context_descriptor(AnonymousContextDescriptor)
31683168
using TrailingGenericContextObjects::numTrailingObjects;
31693169

31703170
size_t numTrailingObjects(OverloadToken<MangledContextName>) const {
3171-
return this->hasMangledNam() ? 1 : 0;
3171+
return this->hasMangledName() ? 1 : 0;
31723172
}
31733173

31743174
public:

include/swift/ABI/TrailingObjects.h

+33
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,21 @@ class TrailingObjectsImpl<Align, BaseTy, TopTrailingObj, PrevTy, NextTy,
207207
sizeof(NextTy) * Count1,
208208
MoreCounts...);
209209
}
210+
211+
// Helper function for TrailingObjects::sizeWithTrailingTypeCount. This
212+
// recurses to superclasses until n reaches zero, then computes the size of
213+
// the object up to that point.
214+
static size_t sizeWithTrailingTypeCountImpl(const BaseTy *Obj, size_t n) {
215+
if (n > 0)
216+
return ParentType::sizeWithTrailingTypeCountImpl(Obj, n - 1);
217+
218+
auto *Ptr = getTrailingObjectsImpl(
219+
Obj, TrailingObjectsBase::OverloadToken<NextTy>());
220+
auto Count = TopTrailingObj::callNumTrailingObjects(
221+
Obj, TrailingObjectsBase::OverloadToken<NextTy>());
222+
auto *End = Ptr + Count;
223+
return (const char *)End - (const char *)Obj;
224+
}
210225
};
211226

212227
// The base case of the TrailingObjectsImpl inheritance recursion,
@@ -225,6 +240,10 @@ class TrailingObjectsImpl<Align, BaseTy, TopTrailingObj, PrevTy>
225240
}
226241

227242
template <bool CheckAlignment> static void verifyTrailingObjectsAlignment() {}
243+
244+
static size_t sizeWithTrailingTypeCountImpl(const BaseTy *Obj, size_t n) {
245+
return 0;
246+
}
228247
};
229248

230249
} // end namespace trailing_objects_internal
@@ -386,6 +405,20 @@ class swift_ptrauth_struct_derived(BaseTy) TrailingObjects
386405

387406
BaseTy *const p;
388407
};
408+
409+
// Get the number of trailing types in this TrailingObjects specialization.
410+
static size_t trailingTypeCount() { return sizeof...(TrailingTys); }
411+
412+
// Get the size of the object including trailing objects through index N. This
413+
// allows working out the size of a TrailingObjects subclass incrementally,
414+
// by calling this repeatedly starting from 0. This is needed for remote
415+
// inspection, which needs to figure out how much memory to read just from the
416+
// contents of the object. It can repeatedly read a prefix until it has the
417+
// whole thing.
418+
size_t sizeWithTrailingTypeCount(size_t n) const {
419+
return ParentType::sizeWithTrailingTypeCountImpl(
420+
static_cast<const BaseTy *>(this), n);
421+
}
389422
};
390423

391424
} // end namespace ABI

include/swift/Remote/MetadataReader.h

+72-99
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#define SWIFT_REMOTE_METADATAREADER_H
1919

2020

21+
#include "swift/ABI/Metadata.h"
2122
#include "swift/Runtime/Metadata.h"
2223
#include "swift/Remote/MemoryReader.h"
2324
#include "swift/Demangling/Demangler.h"
@@ -1451,135 +1452,107 @@ class MetadataReader {
14511452
if (address == 0)
14521453
return nullptr;
14531454

1455+
auto remoteAddress = RemoteAddress(address);
1456+
auto ptr = Reader->readBytes(remoteAddress,
1457+
sizeof(TargetContextDescriptor<Runtime>));
1458+
if (!ptr)
1459+
return nullptr;
1460+
14541461
auto cached = ContextDescriptorCache.find(address);
14551462
if (cached != ContextDescriptorCache.end())
14561463
return ContextDescriptorRef(
14571464
address, reinterpret_cast<const TargetContextDescriptor<Runtime> *>(
14581465
cached->second.get()));
14591466

1460-
// Read the flags to figure out how much space we should read.
1461-
ContextDescriptorFlags flags;
1462-
if (!Reader->readBytes(RemoteAddress(address), (uint8_t*)&flags,
1463-
sizeof(flags)))
1464-
return nullptr;
1465-
1466-
TypeContextDescriptorFlags typeFlags(flags.getKindSpecificFlags());
1467-
uint64_t baseSize = 0;
1468-
uint64_t genericHeaderSize = sizeof(GenericContextDescriptorHeader);
1469-
uint64_t metadataInitSize = 0;
1470-
bool hasVTable = false;
1471-
1472-
auto readMetadataInitSize = [&]() -> unsigned {
1473-
switch (typeFlags.getMetadataInitialization()) {
1474-
case TypeContextDescriptorFlags::NoMetadataInitialization:
1475-
return 0;
1476-
case TypeContextDescriptorFlags::SingletonMetadataInitialization:
1477-
// FIXME: classes
1478-
return sizeof(TargetSingletonMetadataInitialization<Runtime>);
1479-
case TypeContextDescriptorFlags::ForeignMetadataInitialization:
1480-
return sizeof(TargetForeignMetadataInitialization<Runtime>);
1481-
}
1482-
return 0;
1483-
};
1484-
1485-
switch (flags.getKind()) {
1467+
bool success = false;
1468+
switch (
1469+
reinterpret_cast<const TargetContextDescriptor<Runtime> *>(ptr.get())
1470+
->getKind()) {
14861471
case ContextDescriptorKind::Module:
1487-
baseSize = sizeof(TargetModuleContextDescriptor<Runtime>);
1472+
ptr = Reader->readBytes(remoteAddress,
1473+
sizeof(TargetModuleContextDescriptor<Runtime>));
1474+
success = ptr != nullptr;
14881475
break;
1489-
// TODO: Should we include trailing generic arguments in this load?
14901476
case ContextDescriptorKind::Extension:
1491-
baseSize = sizeof(TargetExtensionContextDescriptor<Runtime>);
1477+
success =
1478+
readFullContextDescriptor<TargetExtensionContextDescriptor<Runtime>>(
1479+
remoteAddress, ptr);
14921480
break;
14931481
case ContextDescriptorKind::Anonymous:
1494-
baseSize = sizeof(TargetAnonymousContextDescriptor<Runtime>);
1495-
if (AnonymousContextDescriptorFlags(flags.getKindSpecificFlags())
1496-
.hasMangledName()) {
1497-
metadataInitSize = sizeof(TargetMangledContextName<Runtime>);
1498-
}
1482+
success =
1483+
readFullContextDescriptor<TargetAnonymousContextDescriptor<Runtime>>(
1484+
remoteAddress, ptr);
14991485
break;
15001486
case ContextDescriptorKind::Class:
1501-
baseSize = sizeof(TargetClassDescriptor<Runtime>);
1502-
genericHeaderSize = sizeof(TypeGenericContextDescriptorHeader);
1503-
hasVTable = typeFlags.class_hasVTable();
1504-
metadataInitSize = readMetadataInitSize();
1487+
success = readFullContextDescriptor<TargetClassDescriptor<Runtime>>(
1488+
remoteAddress, ptr);
15051489
break;
15061490
case ContextDescriptorKind::Enum:
1507-
baseSize = sizeof(TargetEnumDescriptor<Runtime>);
1508-
genericHeaderSize = sizeof(TypeGenericContextDescriptorHeader);
1509-
metadataInitSize = readMetadataInitSize();
1491+
success = readFullContextDescriptor<TargetEnumDescriptor<Runtime>>(
1492+
remoteAddress, ptr);
15101493
break;
15111494
case ContextDescriptorKind::Struct:
1512-
baseSize = sizeof(TargetStructDescriptor<Runtime>);
1513-
genericHeaderSize = sizeof(TypeGenericContextDescriptorHeader);
1514-
metadataInitSize = readMetadataInitSize();
1495+
success = readFullContextDescriptor<TargetStructDescriptor<Runtime>>(
1496+
remoteAddress, ptr);
15151497
break;
15161498
case ContextDescriptorKind::Protocol:
1517-
baseSize = sizeof(TargetProtocolDescriptor<Runtime>);
1499+
success = readFullContextDescriptor<TargetProtocolDescriptor<Runtime>>(
1500+
remoteAddress, ptr);
15181501
break;
15191502
case ContextDescriptorKind::OpaqueType:
1520-
baseSize = sizeof(TargetOpaqueTypeDescriptor<Runtime>);
1521-
metadataInitSize =
1522-
sizeof(typename Runtime::template RelativeDirectPointer<const char>)
1523-
* flags.getKindSpecificFlags();
1503+
success = readFullContextDescriptor<TargetOpaqueTypeDescriptor<Runtime>>(
1504+
remoteAddress, ptr);
15241505
break;
15251506
default:
15261507
// We don't know about this kind of context.
15271508
return nullptr;
15281509
}
1529-
1530-
// Determine the full size of the descriptor. This is reimplementing a fair
1531-
// bit of TrailingObjects but for out-of-process; maybe there's a way to
1532-
// factor the layout stuff out...
1533-
uint64_t genericsSize = 0;
1534-
if (flags.isGeneric()) {
1535-
GenericContextDescriptorHeader header;
1536-
auto headerAddr = address
1537-
+ baseSize
1538-
+ genericHeaderSize
1539-
- sizeof(header);
1540-
1541-
if (!Reader->readBytes(RemoteAddress(headerAddr),
1542-
(uint8_t*)&header, sizeof(header)))
1543-
return nullptr;
1544-
1545-
genericsSize = genericHeaderSize
1546-
+ (header.NumParams + 3u & ~3u)
1547-
+ header.NumRequirements
1548-
* sizeof(TargetGenericRequirementDescriptor<Runtime>);
1549-
}
1550-
1551-
uint64_t vtableSize = 0;
1552-
if (hasVTable) {
1553-
TargetVTableDescriptorHeader<Runtime> header;
1554-
auto headerAddr = address
1555-
+ baseSize
1556-
+ genericsSize
1557-
+ metadataInitSize;
1558-
1559-
if (!Reader->readBytes(RemoteAddress(headerAddr),
1560-
(uint8_t*)&header, sizeof(header)))
1561-
return nullptr;
1562-
1563-
vtableSize = sizeof(header)
1564-
+ header.VTableSize * sizeof(TargetMethodDescriptor<Runtime>);
1565-
}
1566-
1567-
uint64_t size = baseSize + genericsSize + metadataInitSize + vtableSize;
1568-
if (size > MaxMetadataSize)
1569-
return nullptr;
1570-
auto readResult = Reader->readBytes(RemoteAddress(address), size);
1571-
if (!readResult)
1510+
if (!success)
15721511
return nullptr;
15731512

1574-
auto descriptor =
1575-
reinterpret_cast<const TargetContextDescriptor<Runtime> *>(
1576-
readResult.get());
1577-
1578-
ContextDescriptorCache.insert(
1579-
std::make_pair(address, std::move(readResult)));
1513+
auto *descriptor =
1514+
reinterpret_cast<const TargetContextDescriptor<Runtime> *>(ptr.get());
1515+
ContextDescriptorCache.insert(std::make_pair(address, std::move(ptr)));
15801516
return ContextDescriptorRef(address, descriptor);
15811517
}
1582-
1518+
1519+
template <typename DescriptorTy>
1520+
bool readFullContextDescriptor(RemoteAddress address,
1521+
MemoryReader::ReadBytesResult &ptr) {
1522+
// Read the full base descriptor if it's bigger than what we have so far.
1523+
if (sizeof(DescriptorTy) > sizeof(TargetContextDescriptor<Runtime>)) {
1524+
ptr = Reader->readObj<DescriptorTy>(address);
1525+
if (!ptr)
1526+
return false;
1527+
}
1528+
1529+
// We don't know how much memory we need to read to get all the trailing
1530+
// objects, but we need to read the memory to figure out how much memory we
1531+
// need to read. Handle this by reading incrementally.
1532+
//
1533+
// We rely on the fact that each trailing object's count depends only on
1534+
// that comes before it. If we've read the first N trailing objects, then we
1535+
// can safely compute the size with N+1 trailing objects. If that size is
1536+
// bigger than what we've read so far, re-read the descriptor with the new
1537+
// size. Once we've walked through all the trailing objects, we've read
1538+
// everything.
1539+
1540+
size_t sizeSoFar = sizeof(DescriptorTy);
1541+
1542+
for (size_t i = 0; i < DescriptorTy::trailingTypeCount(); i++) {
1543+
const DescriptorTy *descriptorSoFar =
1544+
reinterpret_cast<const DescriptorTy *>(ptr.get());
1545+
size_t thisSize = descriptorSoFar->sizeWithTrailingTypeCount(i);
1546+
if (thisSize > sizeSoFar) {
1547+
ptr = Reader->readBytes(address, thisSize);
1548+
if (!ptr)
1549+
return false;
1550+
sizeSoFar = thisSize;
1551+
}
1552+
}
1553+
return true;
1554+
}
1555+
15831556
/// Demangle the entity represented by a symbolic reference to a given symbol name.
15841557
Demangle::NodePointer
15851558
buildContextManglingForSymbol(StringRef symbol, Demangler &dem) {

0 commit comments

Comments
 (0)