Skip to content

Commit 91e7daa

Browse files
committed
[Reflection] Fix undersized reads in MetadataReader::readContextDescriptor.
Determining a descriptor's size requires reading its contents, but reading its contents from out of process requires knowing its size. Build up the size incrementally by walking over the TrailingObjects. Take advantage of the fact that each trailing object's presence/count depends only on data that comes before it. This allows us to read prefixes that we gradually expand until we've covered the whole thing. Add calls to TrailingObjects to allow iterating over the prefix sizes, and modify readContextDescriptor to use them. This replaces the old code which attempted to determine the descriptor size in an ad-hoc fashion that didn't always get it right. rdar://146006006
1 parent a172489 commit 91e7daa

File tree

4 files changed

+114
-100
lines changed

4 files changed

+114
-100
lines changed

include/swift/ABI/GenericContext.h

+8
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,14 @@ class TrailingGenericContextObjects<TargetSelf<Runtime>,
711711
getGenericValueDescriptors().data()};
712712
}
713713

714+
static size_t trailingTypeCount() {
715+
return TrailingObjects::trailingTypeCount();
716+
}
717+
718+
size_t sizeWithTrailingTypeCount(size_t n) const {
719+
return TrailingObjects::sizeWithTrailingTypeCount(n);
720+
}
721+
714722
protected:
715723
size_t numTrailingObjects(OverloadToken<GenericContextHeaderType>) const {
716724
return asSelf()->isGeneric() ? 1 : 0;

include/swift/ABI/Metadata.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -3168,7 +3168,7 @@ struct swift_ptrauth_struct_context_descriptor(AnonymousContextDescriptor)
31683168
using TrailingGenericContextObjects::numTrailingObjects;
31693169

31703170
size_t numTrailingObjects(OverloadToken<MangledContextName>) const {
3171-
return this->hasMangledNam() ? 1 : 0;
3171+
return this->hasMangledName() ? 1 : 0;
31723172
}
31733173

31743174
public:

include/swift/ABI/TrailingObjects.h

+33
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,21 @@ class TrailingObjectsImpl<Align, BaseTy, TopTrailingObj, PrevTy, NextTy,
207207
sizeof(NextTy) * Count1,
208208
MoreCounts...);
209209
}
210+
211+
// Helper function for TrailingObjects::sizeWithTrailingTypeCount. This
212+
// recurses to superclasses until n reaches zero, then computes the size of
213+
// the object up to that point.
214+
static size_t sizeWithTrailingTypeCountImpl(const BaseTy *Obj, size_t n) {
215+
if (n > 0)
216+
return ParentType::sizeWithTrailingTypeCountImpl(Obj, n - 1);
217+
218+
auto *Ptr = getTrailingObjectsImpl(
219+
Obj, TrailingObjectsBase::OverloadToken<NextTy>());
220+
auto Count = TopTrailingObj::callNumTrailingObjects(
221+
Obj, TrailingObjectsBase::OverloadToken<NextTy>());
222+
auto *End = Ptr + Count;
223+
return (const char *)End - (const char *)Obj;
224+
}
210225
};
211226

212227
// The base case of the TrailingObjectsImpl inheritance recursion,
@@ -225,6 +240,10 @@ class TrailingObjectsImpl<Align, BaseTy, TopTrailingObj, PrevTy>
225240
}
226241

227242
template <bool CheckAlignment> static void verifyTrailingObjectsAlignment() {}
243+
244+
static size_t sizeWithTrailingTypeCountImpl(const BaseTy *Obj, size_t n) {
245+
return 0;
246+
}
228247
};
229248

230249
} // end namespace trailing_objects_internal
@@ -386,6 +405,20 @@ class swift_ptrauth_struct_derived(BaseTy) TrailingObjects
386405

387406
BaseTy *const p;
388407
};
408+
409+
// Get the number of trailing types in this TrailingObjects specialization.
410+
static size_t trailingTypeCount() { return sizeof...(TrailingTys); }
411+
412+
// Get the size of the object including trailing objects through index N. This
413+
// allows working out the size of a TrailingObjects subclass incrementally,
414+
// by calling this repeatedly starting from 0. This is needed for remote
415+
// inspection, which needs to figure out how much memory to read just from the
416+
// contents of the object. It can repeatedly read a prefix until it has the
417+
// whole thing.
418+
size_t sizeWithTrailingTypeCount(size_t n) const {
419+
return ParentType::sizeWithTrailingTypeCountImpl(
420+
static_cast<const BaseTy *>(this), n);
421+
}
389422
};
390423

391424
} // end namespace ABI

include/swift/Remote/MetadataReader.h

+72-99
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#define SWIFT_REMOTE_METADATAREADER_H
1919

2020

21+
#include "swift/ABI/Metadata.h"
2122
#include "swift/Runtime/Metadata.h"
2223
#include "swift/Remote/MemoryReader.h"
2324
#include "swift/Demangling/Demangler.h"
@@ -1451,135 +1452,107 @@ class MetadataReader {
14511452
if (address == 0)
14521453
return nullptr;
14531454

1455+
auto remoteAddress = RemoteAddress(address);
1456+
auto ptr = Reader->readBytes(remoteAddress,
1457+
sizeof(TargetContextDescriptor<Runtime>));
1458+
if (!ptr)
1459+
return nullptr;
1460+
14541461
auto cached = ContextDescriptorCache.find(address);
14551462
if (cached != ContextDescriptorCache.end())
14561463
return ContextDescriptorRef(
14571464
address, reinterpret_cast<const TargetContextDescriptor<Runtime> *>(
14581465
cached->second.get()));
14591466

1460-
// Read the flags to figure out how much space we should read.
1461-
ContextDescriptorFlags flags;
1462-
if (!Reader->readBytes(RemoteAddress(address), (uint8_t*)&flags,
1463-
sizeof(flags)))
1464-
return nullptr;
1465-
1466-
TypeContextDescriptorFlags typeFlags(flags.getKindSpecificFlags());
1467-
uint64_t baseSize = 0;
1468-
uint64_t genericHeaderSize = sizeof(GenericContextDescriptorHeader);
1469-
uint64_t metadataInitSize = 0;
1470-
bool hasVTable = false;
1471-
1472-
auto readMetadataInitSize = [&]() -> unsigned {
1473-
switch (typeFlags.getMetadataInitialization()) {
1474-
case TypeContextDescriptorFlags::NoMetadataInitialization:
1475-
return 0;
1476-
case TypeContextDescriptorFlags::SingletonMetadataInitialization:
1477-
// FIXME: classes
1478-
return sizeof(TargetSingletonMetadataInitialization<Runtime>);
1479-
case TypeContextDescriptorFlags::ForeignMetadataInitialization:
1480-
return sizeof(TargetForeignMetadataInitialization<Runtime>);
1481-
}
1482-
return 0;
1483-
};
1484-
1485-
switch (flags.getKind()) {
1467+
bool success = false;
1468+
switch (
1469+
reinterpret_cast<const TargetContextDescriptor<Runtime> *>(ptr.get())
1470+
->getKind()) {
14861471
case ContextDescriptorKind::Module:
1487-
baseSize = sizeof(TargetModuleContextDescriptor<Runtime>);
1472+
ptr = Reader->readBytes(remoteAddress,
1473+
sizeof(TargetModuleContextDescriptor<Runtime>));
1474+
success = ptr != nullptr;
14881475
break;
1489-
// TODO: Should we include trailing generic arguments in this load?
14901476
case ContextDescriptorKind::Extension:
1491-
baseSize = sizeof(TargetExtensionContextDescriptor<Runtime>);
1477+
success =
1478+
readFullContextDescriptor<TargetExtensionContextDescriptor<Runtime>>(
1479+
remoteAddress, ptr);
14921480
break;
14931481
case ContextDescriptorKind::Anonymous:
1494-
baseSize = sizeof(TargetAnonymousContextDescriptor<Runtime>);
1495-
if (AnonymousContextDescriptorFlags(flags.getKindSpecificFlags())
1496-
.hasMangledName()) {
1497-
metadataInitSize = sizeof(TargetMangledContextName<Runtime>);
1498-
}
1482+
success =
1483+
readFullContextDescriptor<TargetAnonymousContextDescriptor<Runtime>>(
1484+
remoteAddress, ptr);
14991485
break;
15001486
case ContextDescriptorKind::Class:
1501-
baseSize = sizeof(TargetClassDescriptor<Runtime>);
1502-
genericHeaderSize = sizeof(TypeGenericContextDescriptorHeader);
1503-
hasVTable = typeFlags.class_hasVTable();
1504-
metadataInitSize = readMetadataInitSize();
1487+
success = readFullContextDescriptor<TargetClassDescriptor<Runtime>>(
1488+
remoteAddress, ptr);
15051489
break;
15061490
case ContextDescriptorKind::Enum:
1507-
baseSize = sizeof(TargetEnumDescriptor<Runtime>);
1508-
genericHeaderSize = sizeof(TypeGenericContextDescriptorHeader);
1509-
metadataInitSize = readMetadataInitSize();
1491+
success = readFullContextDescriptor<TargetEnumDescriptor<Runtime>>(
1492+
remoteAddress, ptr);
15101493
break;
15111494
case ContextDescriptorKind::Struct:
1512-
baseSize = sizeof(TargetStructDescriptor<Runtime>);
1513-
genericHeaderSize = sizeof(TypeGenericContextDescriptorHeader);
1514-
metadataInitSize = readMetadataInitSize();
1495+
success = readFullContextDescriptor<TargetStructDescriptor<Runtime>>(
1496+
remoteAddress, ptr);
15151497
break;
15161498
case ContextDescriptorKind::Protocol:
1517-
baseSize = sizeof(TargetProtocolDescriptor<Runtime>);
1499+
success = readFullContextDescriptor<TargetProtocolDescriptor<Runtime>>(
1500+
remoteAddress, ptr);
15181501
break;
15191502
case ContextDescriptorKind::OpaqueType:
1520-
baseSize = sizeof(TargetOpaqueTypeDescriptor<Runtime>);
1521-
metadataInitSize =
1522-
sizeof(typename Runtime::template RelativeDirectPointer<const char>)
1523-
* flags.getKindSpecificFlags();
1503+
success = readFullContextDescriptor<TargetOpaqueTypeDescriptor<Runtime>>(
1504+
remoteAddress, ptr);
15241505
break;
15251506
default:
15261507
// We don't know about this kind of context.
15271508
return nullptr;
15281509
}
1529-
1530-
// Determine the full size of the descriptor. This is reimplementing a fair
1531-
// bit of TrailingObjects but for out-of-process; maybe there's a way to
1532-
// factor the layout stuff out...
1533-
uint64_t genericsSize = 0;
1534-
if (flags.isGeneric()) {
1535-
GenericContextDescriptorHeader header;
1536-
auto headerAddr = address
1537-
+ baseSize
1538-
+ genericHeaderSize
1539-
- sizeof(header);
1540-
1541-
if (!Reader->readBytes(RemoteAddress(headerAddr),
1542-
(uint8_t*)&header, sizeof(header)))
1543-
return nullptr;
1544-
1545-
genericsSize = genericHeaderSize
1546-
+ (header.NumParams + 3u & ~3u)
1547-
+ header.NumRequirements
1548-
* sizeof(TargetGenericRequirementDescriptor<Runtime>);
1549-
}
1550-
1551-
uint64_t vtableSize = 0;
1552-
if (hasVTable) {
1553-
TargetVTableDescriptorHeader<Runtime> header;
1554-
auto headerAddr = address
1555-
+ baseSize
1556-
+ genericsSize
1557-
+ metadataInitSize;
1558-
1559-
if (!Reader->readBytes(RemoteAddress(headerAddr),
1560-
(uint8_t*)&header, sizeof(header)))
1561-
return nullptr;
1562-
1563-
vtableSize = sizeof(header)
1564-
+ header.VTableSize * sizeof(TargetMethodDescriptor<Runtime>);
1565-
}
1566-
1567-
uint64_t size = baseSize + genericsSize + metadataInitSize + vtableSize;
1568-
if (size > MaxMetadataSize)
1569-
return nullptr;
1570-
auto readResult = Reader->readBytes(RemoteAddress(address), size);
1571-
if (!readResult)
1510+
if (!success)
15721511
return nullptr;
15731512

1574-
auto descriptor =
1575-
reinterpret_cast<const TargetContextDescriptor<Runtime> *>(
1576-
readResult.get());
1577-
1578-
ContextDescriptorCache.insert(
1579-
std::make_pair(address, std::move(readResult)));
1513+
auto *descriptor =
1514+
reinterpret_cast<const TargetContextDescriptor<Runtime> *>(ptr.get());
1515+
ContextDescriptorCache.insert(std::make_pair(address, std::move(ptr)));
15801516
return ContextDescriptorRef(address, descriptor);
15811517
}
1582-
1518+
1519+
template <typename DescriptorTy>
1520+
bool readFullContextDescriptor(RemoteAddress address,
1521+
MemoryReader::ReadBytesResult &ptr) {
1522+
// Read the full base descriptor if it's bigger than what we have so far.
1523+
if (sizeof(DescriptorTy) > sizeof(TargetContextDescriptor<Runtime>)) {
1524+
ptr = Reader->readObj<DescriptorTy>(address);
1525+
if (!ptr)
1526+
return false;
1527+
}
1528+
1529+
// We don't know how much memory we need to read to get all the trailing
1530+
// objects, but we need to read the memory to figure out how much memory we
1531+
// need to read. Handle this by reading incrementally.
1532+
//
1533+
// We rely on the fact that each trailing object's count depends only on
1534+
// that comes before it. If we've read the first N trailing objects, then we
1535+
// can safely compute the size with N+1 trailing objects. If that size is
1536+
// bigger than what we've read so far, re-read the descriptor with the new
1537+
// size. Once we've walked through all the trailing objects, we've read
1538+
// everything.
1539+
1540+
size_t sizeSoFar = sizeof(DescriptorTy);
1541+
1542+
for (size_t i = 0; i < DescriptorTy::trailingTypeCount(); i++) {
1543+
const DescriptorTy *descriptorSoFar =
1544+
reinterpret_cast<const DescriptorTy *>(ptr.get());
1545+
size_t thisSize = descriptorSoFar->sizeWithTrailingTypeCount(i);
1546+
if (thisSize > sizeSoFar) {
1547+
ptr = Reader->readBytes(address, thisSize);
1548+
if (!ptr)
1549+
return false;
1550+
sizeSoFar = thisSize;
1551+
}
1552+
}
1553+
return true;
1554+
}
1555+
15831556
/// Demangle the entity represented by a symbolic reference to a given symbol name.
15841557
Demangle::NodePointer
15851558
buildContextManglingForSymbol(StringRef symbol, Demangler &dem) {

0 commit comments

Comments
 (0)