Skip to content

Commit

Permalink
Merge pull request #9 from petiaccja/improvements
Browse files Browse the repository at this point in the history
improvements
  • Loading branch information
petiaccja authored Oct 31, 2023
2 parents 8de4545 + 7255ef9 commit e818dfa
Show file tree
Hide file tree
Showing 20 changed files with 722 additions and 363 deletions.
1 change: 1 addition & 0 deletions .clang-format
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ Language: Cpp
MaxEmptyLinesToKeep: '3'
NamespaceIndentation: Inner
PointerAlignment: Left
PackConstructorInitializers: NextLine
SortIncludes: 'true'
SortUsingDeclarations: 'false'
SpaceAfterCStyleCast: 'false'
Expand Down
9 changes: 0 additions & 9 deletions src/SEDManager/SEDManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,6 @@ std::unordered_map<std::string, uint32_t> SEDManager::GetProperties() {
}


std::vector<NamedObject> SEDManager::GetNamedRows(const Table& table) {
std::vector<NamedObject> namedRows;
for (const auto& row : table) {
namedRows.emplace_back(row.Id(), GetModules().FindName(row.Id()).value_or(to_string(row.Id())));
}
return namedRows;
}


const TPerModules& SEDManager::GetModules() const {
return m_tper->GetModules();
}
Expand Down
1 change: 0 additions & 1 deletion src/SEDManager/SEDManager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ class SEDManager {
void Reset();

private:
std::vector<NamedObject> GetNamedRows(const Table& table);
void LaunchStack();

private:
Expand Down
61 changes: 37 additions & 24 deletions src/SEDManagerCLI/Interactive.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,10 @@ auto Interactive::ParseGetSet(std::string rowName, int32_t column) const -> std:
throw std::invalid_argument("specify object as 'Table::Object'");
}

const auto tableUid = Unwrap(FindOrParseUid(m_manager, rowNameSections[0], m_currentSecurityProvider), "cannot find table");
const auto maybeRowUid = FindOrParseUid(m_manager, rowName, m_currentSecurityProvider).value_or(FindOrParseUid(m_manager, rowNameSections[1], m_currentSecurityProvider).value_or(0));
const auto tableUid = Unwrap(ParseObjectRef(m_manager, rowNameSections[0], m_currentSecurityProvider), "cannot find table");
const auto maybeRowUid = ParseObjectRef(m_manager, rowName, m_currentSecurityProvider)
.value_or(ParseObjectRef(m_manager, rowNameSections[1], m_currentSecurityProvider)
.value_or(0));
if (maybeRowUid == Uid(0)) {
throw std::invalid_argument("cannot find object");
}
Expand All @@ -108,10 +110,10 @@ void Interactive::PrintCaret() const {
std::vector<std::string> currentAuthNames;
const auto currentSP = GetCurrentSecurityProvider();
if (currentSP) {
currentSPName = m_manager.GetModules().FindName(*currentSP).value_or(to_string(*currentSP));
currentSPName = FormatObjectRef(m_manager, *currentSP);
for (auto auth : GetCurrentAuthorities()) {
const std::string n = m_manager.GetModules().FindName(auth, *currentSP).value_or(to_string(auth));
currentAuthNames.push_back(std::string(SplitName(n).back()));
const std::string authName = FormatObjectRef(m_manager, auth, *currentSP);
currentAuthNames.push_back(std::string(SplitName(authName).back()));
}
}

Expand Down Expand Up @@ -168,7 +170,7 @@ void Interactive::RegisterCallbackStart() {
auto cmd = m_cli.add_subcommand("start", "Start a session with a service provider.");
cmd->add_option("sp", spName, "The name or UID (in hex) of the security provider.")->required();
cmd->callback([this] {
const auto spUid = Unwrap(FindOrParseUid(m_manager, "SP::" + spName, m_currentSecurityProvider), "cannot find security provider");
const auto spUid = Unwrap(ParseObjectRef(m_manager, "SP::" + spName, m_currentSecurityProvider), "cannot find security provider");
m_manager.Start(spUid);
m_currentSecurityProvider = spUid;
});
Expand All @@ -181,7 +183,7 @@ void Interactive::RegisterCallbackAuthenticate() {
auto cmd = m_cli.add_subcommand("auth", "Authenticate with an authority.");
cmd->add_option("authority", authName, "The name or UID (in hex) of the security provider.")->required();
cmd->callback([this] {
const auto authUid = Unwrap(FindOrParseUid(m_manager, "Authority::" + authName, m_currentSecurityProvider), "cannot find authority");
const auto authUid = Unwrap(ParseObjectRef(m_manager, "Authority::" + authName, m_currentSecurityProvider), "cannot find authority");
const auto password = GetPassword("Password: ");
m_manager.Authenticate(authUid, password);
m_currentAuthorities.insert(authUid);
Expand Down Expand Up @@ -280,7 +282,7 @@ void Interactive::RegisterCallbackFind() {
auto cmd = m_cli.add_subcommand("find", "Finds the name and UID of an object given as name or UID.");
cmd->add_option("object", objectName)->required();
cmd->callback([this] {
const auto objectUid = Unwrap(FindOrParseUid(m_manager, objectName, m_currentSecurityProvider), "cannot find object");
const auto objectUid = Unwrap(ParseObjectRef(m_manager, objectName, m_currentSecurityProvider), "cannot find object");
const auto maybeName = m_manager.GetModules().FindName(objectUid, m_currentSecurityProvider);
std::cout << "UID: " << to_string(objectUid) << std::endl;
std::cout << "Name: " << maybeName.value_or("<not found>") << std::endl;
Expand All @@ -294,7 +296,7 @@ void Interactive::RegisterCallbackRows() {
auto cmdRows = m_cli.add_subcommand("rows", "List the rows of the table.");
cmdRows->add_option("table", tableName, "The table to list the rows of.")->required();
cmdRows->callback([this] {
const auto tableUid = Unwrap(FindOrParseUid(m_manager, tableName, m_currentSecurityProvider), "cannot find table");
const auto tableUid = Unwrap(ParseObjectRef(m_manager, tableName, m_currentSecurityProvider), "cannot find table");
const auto table = m_manager.GetTable(tableUid);

const std::vector<std::string> columnNames = { "UID", "Name" };
Expand All @@ -313,7 +315,7 @@ void Interactive::RegisterCallbackColumns() {
auto cmdColumns = m_cli.add_subcommand("columns", "List the columns of the table.");
cmdColumns->add_option("table", tableName, "The table to list the columns of.")->required();
cmdColumns->callback([this] {
const auto tableUid = Unwrap(FindOrParseUid(m_manager, tableName, m_currentSecurityProvider), "cannot find table");
const auto tableUid = Unwrap(ParseObjectRef(m_manager, tableName, m_currentSecurityProvider), "cannot find table");
const auto table = m_manager.GetTable(tableUid);

size_t columnNumber = 0;
Expand All @@ -330,7 +332,7 @@ void Interactive::RegisterCallbackColumns() {
void Interactive::RegisterCallbackGet() {
static std::string rowName;
static int32_t column = -1;
static std::string jsonValue;
static std::string jsonStr;

auto cmdGet = m_cli.add_subcommand("get", "Get a cell from a table.");

Expand All @@ -343,6 +345,7 @@ void Interactive::RegisterCallbackGet() {
}
const auto& [tableUid, rowUid, column] = *parsed;
const auto object = m_manager.GetObject(tableUid, rowUid);
const auto nameConverter = [this](Uid uid) { return m_manager.GetModules().FindName(uid, m_currentSecurityProvider); };
if (column < 0) {
const std::vector<std::string> outColumns = { "Column", "Value" };
std::vector<std::vector<std::string>> outData;
Expand All @@ -351,7 +354,11 @@ void Interactive::RegisterCallbackGet() {
const auto label = std::format("{}: {}", idx, columnDesc.name);
try {
const auto value = *object[idx];
const auto valueStr = value.HasValue() ? ValueToJSON(value, columnDesc.type).dump() : "<empty>";
auto valueStr = value.HasValue() ? ValueToJSON(value, columnDesc.type, nameConverter).dump() : "<empty>";
if (valueStr.size() > 55) {
valueStr.resize(51);
valueStr += " ...";
}
outData.push_back({ label, valueStr });
}
catch (std::exception& ex) {
Expand All @@ -366,7 +373,8 @@ void Interactive::RegisterCallbackGet() {
throw std::invalid_argument("column index is out of bounds.");
}
const auto value = *object[column];
std::cout << (value.HasValue() ? ValueToJSON(value, object.GetDesc()[column].type).dump() : "<empty>") << std::endl;
const auto columnType = object.GetDesc()[column].type;
std::cout << (value.HasValue() ? ValueToJSON(value, columnType, nameConverter).dump(4) : "<empty>") << std::endl;
}
});
}
Expand All @@ -375,25 +383,30 @@ void Interactive::RegisterCallbackGet() {
void Interactive::RegisterCallbackSet() {
static std::string rowName;
static int32_t column = 0;
static std::string jsonValue;
static std::string jsonStr;

auto cmdSet = m_cli.add_subcommand("set", "Set a cell in a table to new value.");

cmdSet->add_option("object", rowName, "The object to set the cells of. Format as 'Table::Object'.")->required();
cmdSet->add_option("column", column, "The column to set.")->required();
auto valueOption = cmdSet->add_option("value", jsonValue, "The new value of the cell.");
auto valueOption = cmdSet->add_option("value", jsonStr, "The new value of the cell.");
cmdSet->callback([this, valueOption] {
const auto parsed = ParseGetSet(rowName, column);
if (!parsed) {
return;
}
if (!*valueOption) {
std::cout << "Reading value until you type 'END' on a new line:" << std::endl;
jsonValue = GetUntilMarker("END");
jsonStr = GetMultiline("END");
}
const auto [tableUid, rowUid, column] = *parsed;
auto object = m_manager.GetObject(tableUid, rowUid);
const auto value = JSONToValue(nlohmann::json::parse(jsonValue), object.GetDesc()[column].type);

const auto jsonObject = nlohmann::json::parse(jsonStr);
const auto columnType = object.GetDesc()[column].type;
const auto nameConverter = [this](std::string_view name) { return m_manager.GetModules().FindUid(name, m_currentSecurityProvider); };
const auto value = JSONToValue(jsonObject, columnType, nameConverter);

object[column] = value;
});
}
Expand All @@ -405,9 +418,9 @@ void Interactive::RegisterCallbackPasswd() {
auto cmd = m_cli.add_subcommand("passwd", "Change the password of an authority.");
cmd->add_option("authority", authName, "The name or UID (in hex) of the authority.")->required();
cmd->callback([this] {
const auto authTable = Unwrap(FindOrParseUid(m_manager, "Authority", m_currentSecurityProvider), "cannot find Authority table");
const auto cPinTable = Unwrap(FindOrParseUid(m_manager, "C_PIN", m_currentSecurityProvider), "cannot find C_PIN table");
const auto authUid = Unwrap(FindOrParseUid(m_manager, "Authority::" + authName, m_currentSecurityProvider), "cannot find authority");
const auto authTable = Unwrap(ParseObjectRef(m_manager, "Authority", m_currentSecurityProvider), "cannot find Authority table");
const auto cPinTable = Unwrap(ParseObjectRef(m_manager, "C_PIN", m_currentSecurityProvider), "cannot find C_PIN table");
const auto authUid = Unwrap(ParseObjectRef(m_manager, "Authority::" + authName, m_currentSecurityProvider), "cannot find authority");
const auto authority = m_manager.GetObject(authTable, authUid);
const auto credentialUid = value_cast<Uid>(*authority[10]);
auto credential = m_manager.GetObject(cPinTable, credentialUid);
Expand All @@ -426,7 +439,7 @@ void Interactive::RegisterCallbackGenMEK() {
auto cmd = m_cli.add_subcommand("gen-mek", "Creates a new Media Encryption Key for a locking range. ERASES RANGE!");
cmd->add_option("range", rangeName, "The locking range.")->required();
cmd->callback([&] {
const auto rangeUid = Unwrap(FindOrParseUid(m_manager, rangeName, m_currentSecurityProvider), "cannot find locking range");
const auto rangeUid = Unwrap(ParseObjectRef(m_manager, rangeName, m_currentSecurityProvider), "cannot find locking range");
m_manager.GenMEK(rangeUid);
});
}
Expand All @@ -438,7 +451,7 @@ void Interactive::RegisterCallbackGenPIN() {
auto cmd = m_cli.add_subcommand("gen-pin", "Creates a new random password for an authority.");
cmd->add_option("c-pin-obj", credentialObj, "The authority's credential object in C_PIN.")->required();
cmd->callback([&] {
const auto credentialUid = Unwrap(FindOrParseUid(m_manager, credentialObj, m_currentSecurityProvider), "cannot find credential object");
const auto credentialUid = Unwrap(ParseObjectRef(m_manager, credentialObj, m_currentSecurityProvider), "cannot find credential object");
m_manager.GenMEK(credentialUid);
});
}
Expand All @@ -449,7 +462,7 @@ void Interactive::RegisterCallbackActivate() {
auto cmd = m_cli.add_subcommand("activate", "Activate an SP that's been disabled the manufacturer.");
cmd->add_option("sp", spName, "The name or UID (in hex) of the security provider.")->required();
cmd->callback([&] {
const auto spUid = Unwrap(FindOrParseUid(m_manager, spName, m_currentSecurityProvider), "cannot find security provider");
const auto spUid = Unwrap(ParseObjectRef(m_manager, spName, m_currentSecurityProvider), "cannot find security provider");
m_manager.Activate(spUid);
});
}
Expand All @@ -460,7 +473,7 @@ void Interactive::RegisterCallbackRevert() {
auto cmd = m_cli.add_subcommand("revert", "Revert an SP to Original Manufacturing State. MAY ERASE DRIVE!");
cmd->add_option("sp", spName, "The name or UID (in hex) of the security provider.")->required();
cmd->callback([&] {
const auto spUid = Unwrap(FindOrParseUid(m_manager, spName, m_currentSecurityProvider), "cannot find security provider");
const auto spUid = Unwrap(ParseObjectRef(m_manager, spName, m_currentSecurityProvider), "cannot find security provider");
m_manager.Revert(spUid);
ClearCurrents();
});
Expand Down
4 changes: 2 additions & 2 deletions src/SEDManagerCLI/PBA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ std::string FormatName(std::string_view name, const std::optional<std::string> c

std::optional<Uid> FindAuthority(SEDManager& manager, std::string_view name) {
const auto lockingSpUid = Unwrap(manager.GetModules().FindUid("SP::Locking"), "could not find Locking SP");
const auto maybeUid = FindOrParseUid(manager, std::format("Authority::{}", name), lockingSpUid);
const auto maybeUid = ParseObjectRef(manager, std::format("Authority::{}", name), lockingSpUid);
if (maybeUid) {
return maybeUid;
}
Expand Down Expand Up @@ -170,7 +170,7 @@ void TryUnlockRanges(SEDManager& manager) {

for (auto lockingRange : lockingRanges) {
const auto uid = lockingRange.Id();
const auto name = manager.GetModules().FindName(uid, lockingSp).value_or(to_string(lockingRange.Id()));
const auto name = FormatObjectRef(manager, lockingRange.Id(), lockingSp);
const auto commonName = GetCommonName(lockingRange, 2);

bool rdUnlocked = false;
Expand Down
21 changes: 13 additions & 8 deletions src/SEDManagerCLI/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ void FlushCin() {
}


std::string GetUntilMarker(std::string_view marker) {
std::string GetMultiline(std::string_view terminator) {
std::string text;
std::string line;
while (std::getline(std::cin, line)) {
if (line == marker) {
if (line == terminator) {
break;
}
text += line;
Expand All @@ -62,17 +62,13 @@ std::string GetUntilMarker(std::string_view marker) {
}


std::optional<Uid> FindOrParseUid(SEDManager& app, std::string_view nameOrUid, std::optional<Uid> sp) {
std::optional<Uid> ParseObjectRef(SEDManager& app, std::string_view nameOrUid, std::optional<Uid> sp) {
const auto maybeUid = app.GetModules().FindUid(nameOrUid, sp);
if (maybeUid) {
return *maybeUid;
}
try {
size_t index = 0;
const uint64_t parsedUid = std::stoull(std::string(nameOrUid), &index, 16);
if (index == nameOrUid.size()) {
return Uid(parsedUid);
}
return stouid(nameOrUid);
}
catch (...) {
// Fallthrough
Expand All @@ -81,6 +77,15 @@ std::optional<Uid> FindOrParseUid(SEDManager& app, std::string_view nameOrUid, s
}


std::string FormatObjectRef(SEDManager& app, Uid uid, std::optional<Uid> sp) {
const auto maybeName = app.GetModules().FindName(uid, sp);
if (maybeName) {
return *maybeName;
}
return to_string(uid);
}


std::vector<std::string_view> SplitName(std::string_view name) {
std::vector<std::string_view> sections;
size_t pos = 0;
Expand Down
8 changes: 4 additions & 4 deletions src/SEDManagerCLI/Utility.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@


std::vector<std::byte> GetPassword(std::string_view prompt);
std::string GetMultiline(std::string_view terminator);

std::string GetUntilMarker(std::string_view marker);

std::optional<Uid> FindOrParseUid(SEDManager& app, std::string_view nameOrUid, std::optional<Uid> sp = {});
std::optional<Uid> ParseObjectRef(SEDManager& app, std::string_view nameOrUid, std::optional<Uid> sp = {});
std::string FormatObjectRef(SEDManager& app, Uid uid, std::optional<Uid> sp = {});
std::string FormatTable(std::span<const std::string> columns, std::span<const std::vector<std::string>> rows);

std::vector<std::string_view> SplitName(std::string_view name);

std::string FormatTable(std::span<const std::string> columns, std::span<const std::vector<std::string>> rows);

template <class T>
const T& Unwrap(const std::optional<T>& maybeValue, std::string_view message = {}) {
Expand Down
Loading

0 comments on commit e818dfa

Please sign in to comment.