Program Listing for File ParameterTree.cpp
↰ Return to documentation for file (src/microsoft/shortlist/utils/ParameterTree.cpp)
#include "microsoft/shortlist/utils/ParameterTree.h"
#include <cstring>
#include <sstream>
#include <string>
#include <utility>
#include "microsoft/shortlist/utils/StringUtils.h"
#include "microsoft/shortlist/utils/Converter.h"
namespace marian {
namespace quicksand {
#include "microsoft/shortlist/logging/LoggerMacros.h"
// Class-wide tree named "params", built when this translation unit's static objects are
// initialized. It is not referenced anywhere else in this file.
std::shared_ptr<ParameterTree> ParameterTree::m_empty_tree(std::make_shared<ParameterTree>("params"));
// Creates an empty node named "root" with no text and no children.
// The name is set in the member initializer list rather than assigned in the body,
// so m_name is constructed once instead of default-constructed and then overwritten.
ParameterTree::ParameterTree() : m_name("root") {
}
// Creates an empty node called `name` with no text and no children.
// Uses the member initializer list so m_name is copy-constructed directly from `name`.
ParameterTree::ParameterTree(const std::string& name) : m_name(name) {
}
// Nothing to release by hand: the members clean themselves up when the node is destroyed.
ParameterTree::~ParameterTree() = default;
// Currently a no-op: it does not modify m_name, m_text, m_children, or the
// registered-parameter bookkeeping.
// NOTE(review): the name suggests the tree should be reset here; confirm with callers
// whether an empty body is what they expect before relying on it.
void ParameterTree::Clear() {
}
void ParameterTree::ReplaceVariables(
const std::unordered_map<std::string, std::string>& vars,
bool error_on_unknown_vars)
{
ReplaceVariablesInternal(vars, error_on_unknown_vars);
}
// Binds the int32 variable at `param` to the parameter `name` (type tag PARAM_TYPE_INT32).
// The pointer is only stored here; the caller must keep the pointee alive for as long as
// the registration is used. Registering the same name twice throws (see RegisterItemInternal).
// Uses a named cast instead of the previous C-style (void *) cast.
void ParameterTree::RegisterInt32(const std::string& name, int32_t * param) {
    RegisterItemInternal(name, PARAM_TYPE_INT32, static_cast<void *>(param));
}
// Binds the int64 variable at `param` to the parameter `name` (type tag PARAM_TYPE_INT64).
// The pointer is only stored here; the caller must keep the pointee alive for as long as
// the registration is used. Registering the same name twice throws (see RegisterItemInternal).
// Uses a named cast instead of the previous C-style (void *) cast.
void ParameterTree::RegisterInt64(const std::string& name, int64_t * param) {
    RegisterItemInternal(name, PARAM_TYPE_INT64, static_cast<void *>(param));
}
// Binds the float variable at `param` to the parameter `name` (type tag PARAM_TYPE_FLOAT).
// The pointer is only stored here; the caller must keep the pointee alive for as long as
// the registration is used. Registering the same name twice throws (see RegisterItemInternal).
// Uses a named cast instead of the previous C-style (void *) cast.
void ParameterTree::RegisterFloat(const std::string& name, float * param) {
    RegisterItemInternal(name, PARAM_TYPE_FLOAT, static_cast<void *>(param));
}
// Binds the double variable at `param` to the parameter `name` (type tag PARAM_TYPE_DOUBLE).
// The pointer is only stored here; the caller must keep the pointee alive for as long as
// the registration is used. Registering the same name twice throws (see RegisterItemInternal).
// Uses a named cast instead of the previous C-style (void *) cast.
void ParameterTree::RegisterDouble(const std::string& name, double * param) {
    RegisterItemInternal(name, PARAM_TYPE_DOUBLE, static_cast<void *>(param));
}
// Binds the bool variable at `param` to the parameter `name` (type tag PARAM_TYPE_BOOL).
// The pointer is only stored here; the caller must keep the pointee alive for as long as
// the registration is used. Registering the same name twice throws (see RegisterItemInternal).
// Uses a named cast instead of the previous C-style (void *) cast.
void ParameterTree::RegisterBool(const std::string& name, bool * param) {
    RegisterItemInternal(name, PARAM_TYPE_BOOL, static_cast<void *>(param));
}
// Binds the std::string variable at `param` to the parameter `name` (type tag PARAM_TYPE_STRING).
// The pointer is only stored here; the caller must keep the pointee alive for as long as
// the registration is used. Registering the same name twice throws (see RegisterItemInternal).
// Uses a named cast instead of the previous C-style (void *) cast.
void ParameterTree::RegisterString(const std::string& name, std::string * param) {
    RegisterItemInternal(name, PARAM_TYPE_STRING, static_cast<void *>(param));
}
std::shared_ptr<ParameterTree> ParameterTree::FromBinaryReader(const void*& current) {
std::shared_ptr<ParameterTree> root = std::make_shared<ParameterTree>();
root->ReadBinary(current);
return root;
}
void ParameterTree::SetRegisteredParams() {
for (std::size_t i = 0; i < m_registered_params.size(); i++) {
const RegisteredParam& rp = m_registered_params[i];
switch (rp.Type()) {
case PARAM_TYPE_INT32:
(*(int32_t *)rp.Data()) = GetInt32Req(rp.Name());
break;
case PARAM_TYPE_INT64:
(*(int64_t *)rp.Data()) = GetInt64Req(rp.Name());
break;
default:
LOG_ERROR_AND_THROW("Unknown ParameterType: %d", (int)rp.Type());
}
}
}
// Returns child parameter `name` converted with Converter::ToInt32.
// Returns `defaultValue` when no child with that name exists.
int32_t ParameterTree::GetInt32Or(const std::string& name, int32_t defaultValue) const {
    const std::string * const text = GetParamInternal(name);
    return (text != nullptr) ? Converter::ToInt32(*text) : defaultValue;
}
// Returns child parameter `name` converted with Converter::ToInt64.
// Returns `defaultValue` when no child with that name exists.
int64_t ParameterTree::GetInt64Or(const std::string& name, int64_t defaultValue) const {
    const std::string * const text = GetParamInternal(name);
    return (text != nullptr) ? Converter::ToInt64(*text) : defaultValue;
}
// Returns child parameter `name` converted with Converter::ToUInt64.
// Returns `defaultValue` when no child with that name exists.
uint64_t ParameterTree::GetUInt64Or(const std::string& name, uint64_t defaultValue) const {
    const std::string * const text = GetParamInternal(name);
    return (text != nullptr) ? Converter::ToUInt64(*text) : defaultValue;
}
// Returns child parameter `name` converted with Converter::ToDouble.
// Returns `defaultValue` when no child with that name exists.
double ParameterTree::GetDoubleOr(const std::string& name, double defaultValue) const {
    const std::string * const text = GetParamInternal(name);
    return (text != nullptr) ? Converter::ToDouble(*text) : defaultValue;
}
// Returns child parameter `name` converted with Converter::ToFloat.
// Returns `defaultValue` when no child with that name exists.
float ParameterTree::GetFloatOr(const std::string& name, float defaultValue) const {
    const std::string * const text = GetParamInternal(name);
    return (text != nullptr) ? Converter::ToFloat(*text) : defaultValue;
}
// Returns a copy of the raw text of child parameter `name`.
// Returns a copy of `defaultValue` when no child with that name exists.
std::string ParameterTree::GetStringOr(const std::string& name, const std::string& defaultValue) const {
    const std::string * const text = GetParamInternal(name);
    return (text != nullptr) ? *text : defaultValue;
}
// Returns child parameter `name` converted with Converter::ToBool.
// Returns `defaultValue` when no child with that name exists.
bool ParameterTree::GetBoolOr(const std::string& name, bool defaultValue) const {
    const std::string * const text = GetParamInternal(name);
    return (text != nullptr) ? Converter::ToBool(*text) : defaultValue;
}
// Required int32 parameter: GetStringReq (which raises an error if `name` is missing)
// followed by Converter::ToInt32.
int32_t ParameterTree::GetInt32Req(const std::string& name) const {
    return Converter::ToInt32(GetStringReq(name));
}
// Required uint64 parameter: GetStringReq (which raises an error if `name` is missing)
// followed by Converter::ToUInt64.
uint64_t ParameterTree::GetUInt64Req(const std::string& name) const {
    return Converter::ToUInt64(GetStringReq(name));
}
// Required int64 parameter: GetStringReq (which raises an error if `name` is missing)
// followed by Converter::ToInt64.
int64_t ParameterTree::GetInt64Req(const std::string& name) const {
    return Converter::ToInt64(GetStringReq(name));
}
// Required double parameter: GetStringReq (which raises an error if `name` is missing)
// followed by Converter::ToDouble.
double ParameterTree::GetDoubleReq(const std::string& name) const {
    return Converter::ToDouble(GetStringReq(name));
}
// Required float parameter: GetStringReq (which raises an error if `name` is missing)
// followed by Converter::ToFloat.
float ParameterTree::GetFloatReq(const std::string& name) const {
    return Converter::ToFloat(GetStringReq(name));
}
// Required bool parameter: GetStringReq (which raises an error if `name` is missing)
// followed by Converter::ToBool.
bool ParameterTree::GetBoolReq(const std::string& name) const {
    return Converter::ToBool(GetStringReq(name));
}
// Returns a copy of the raw text of child parameter `name`.
// If no child has that name, raises an error whose message includes the whole tree
// (via ToString), to help diagnose configuration mistakes.
std::string ParameterTree::GetStringReq(const std::string& name) const {
    const std::string * text = GetParamInternal(name);
    if (!text) {
        LOG_ERROR_AND_THROW("Required parameter <%s> not found in ParameterTree:\n%s", name.c_str(), ToString().c_str());
    }
    return std::string(*text);
}
// Same as GetFileListOptional, but raises an error when the resulting list is empty
// (for example, the parameter is missing or its text is empty).
std::vector<std::string> ParameterTree::GetFileListReq(const std::string& name) const {
    std::vector<std::string> files = GetFileListOptional(name);
    if (files.empty()) {
        LOG_ERROR_AND_THROW("No files were found for parameter: %s", name.c_str());
    }
    return files;
}
// Splits the text of child parameter `name` on ";" into a list of file names.
// Returns an empty list when the parameter is missing or its text is empty.
std::vector<std::string> ParameterTree::GetFileListOptional(const std::string& name) const {
    const std::string * const text = GetParamInternal(name);
    if (text == nullptr || text->empty()) {
        return {};
    }
    return StringUtils::Split(*text, ";");
}
// Splits the text of required parameter `name` on `sep`.
// GetStringReq raises an error when the parameter is missing.
std::vector<std::string> ParameterTree::GetStringListReq(const std::string& name, const std::string& sep) const {
    return StringUtils::Split(GetStringReq(name), sep);
}
// Splits the text of parameter `name` on `sep`. A missing parameter is treated as the
// empty string, so the result is whatever StringUtils::Split returns for "".
std::vector<std::string> ParameterTree::GetStringListOptional(const std::string& name, const std::string& sep) const {
    return StringUtils::Split(GetStringOr(name, ""), sep);
}
// Returns the first direct child called `name`, sharing ownership with this tree.
// Raises an error when no such child exists.
std::shared_ptr<ParameterTree> ParameterTree::GetChildReq(const std::string& name) const {
    for (std::size_t idx = 0; idx < m_children.size(); ++idx) {
        if (m_children[idx]->Name() == name) {
            return m_children[idx];
        }
    }
    LOG_ERROR_AND_THROW("Unable to find child ParameterTree with name '%s'", name.c_str());
    return nullptr; // not reached: the macro above is expected to throw
}
// Returns the first direct child called `name`. If there is none, returns a newly
// allocated, default-constructed tree (named "root", no children). This is not the
// shared m_empty_tree, so callers can modify the result without affecting anyone else.
std::shared_ptr<ParameterTree> ParameterTree::GetChildOrEmpty(const std::string& name) const {
    for (auto it = m_children.begin(); it != m_children.end(); ++it) {
        if ((*it)->Name() == name) {
            return *it;
        }
    }
    return std::make_shared<ParameterTree>();
}
namespace {
// Views `current` as a pointer to T and advances `current` past `num` consecutive T
// elements. Returns the pointer to the first element.
//
// There is no bounds checking; the caller must know the buffer holds at least `num` T's.
// Dereferencing the returned pointer is only safe when `current` is suitably aligned for T.
// That always holds for T = char.
//
// The helper lives in an anonymous namespace: it is private to this file, and a generic
// external-linkage name like `get` risks ODR clashes with other translation units.
template <typename T>
const T* get(const void*& current, std::size_t num = 1) {
    const T* ptr = static_cast<const T*>(current);
    current = ptr + num;
    return ptr;
}
}  // namespace
void ParameterTree::ReadBinary(const void*& current) {
auto nameLength = *get<int32_t>(current);
auto nameBytes = get<char>(current, nameLength);
m_name = std::string(nameBytes, nameBytes + nameLength);
auto textLength = *get<int32_t>(current);
auto textBytes = get<char>(current, textLength);
m_text = std::string(textBytes, textBytes + textLength);
int32_t num_children = *get<int32_t>(current);
m_children.resize(num_children);
for (int32_t i = 0; i < num_children; i++) {
m_children[i].reset(new ParameterTree());
m_children[i]->ReadBinary(current);
}
}
std::vector< std::shared_ptr<ParameterTree> > ParameterTree::GetChildren(const std::string& name) const {
std::vector< std::shared_ptr<ParameterTree> > children;
for (std::shared_ptr<ParameterTree> child : m_children) {
if (child->Name() == name) {
children.push_back(child);
}
}
return children;
}
void ParameterTree::AddParam(const std::string& name, const std::string& text) {
std::shared_ptr<ParameterTree> child = std::make_shared<ParameterTree>(name);
child->SetText(text);
m_children.push_back(child);
}
void ParameterTree::SetParam(const std::string& name, const std::string& text) {
for (const auto& child : m_children) {
if (child->Name() == name) {
child->SetText(text);
return;
}
}
std::shared_ptr<ParameterTree> child = std::make_shared<ParameterTree>(name);
child->SetText(text);
m_children.push_back(child);
}
// Appends `child` as-is. The node is not copied and its name is not checked, so this tree
// shares ownership with whoever else holds the pointer.
// NOTE(review): a null `child` is not rejected here, but the lookup methods dereference
// every child (child->Name()).
//
// Fix: the by-value sink parameter is now moved into the vector instead of copied,
// avoiding a redundant atomic refcount increment and decrement.
void ParameterTree::AddChild(std::shared_ptr<ParameterTree> child) {
    m_children.push_back(std::move(child));
}
// True when some direct child is called `name`, i.e. GetParamInternal finds its text.
bool ParameterTree::HasParam(const std::string& name) const {
    return GetParamInternal(name) != nullptr;
}
// True when at least one direct child (not a deeper descendant) is called `name`.
bool ParameterTree::HasChild(const std::string& name) const {
    bool found = false;
    for (std::size_t idx = 0; !found && idx < m_children.size(); ++idx) {
        found = (m_children[idx]->Name() == name);
    }
    return found;
}
std::string ParameterTree::ToString() const {
std::ostringstream ss;
ToStringInternal(0, ss);
return ss.str();
}
// Returns a pointer to the text of the first direct child called `name`, or nullptr if
// there is none. The pointer refers into that child, so it stays valid only while the
// child is alive and its text is not replaced.
const std::string * ParameterTree::GetParamInternal(const std::string& name) const {
    const std::string * found = nullptr;
    for (std::size_t idx = 0; idx < m_children.size(); ++idx) {
        if (m_children[idx]->Name() == name) {
            found = &m_children[idx]->Text();
            break;
        }
    }
    return found;
}
// Records a (name, type, pointer) binding for SetRegisteredParams. Each name may be
// registered only once; a repeat raises an error before anything is stored.
void ParameterTree::RegisterItemInternal(const std::string& name, ParameterType type, void * param) {
    const bool alreadyRegistered =
        m_registered_param_names.find(name) != m_registered_param_names.end();
    if (alreadyRegistered) {
        LOG_ERROR_AND_THROW("Unable to register duplicate parameter name: '%s'", name.c_str());
    }
    m_registered_params.push_back(RegisteredParam(name, type, param));
    m_registered_param_names.insert(name);
}
// Writes this subtree to `ss` as XML-like text, indented by two spaces per level of
// `depth`. A leaf prints "<name>text</name>" on a single line. A node with children
// prints its tag on its own lines around the children and ignores its own text.
// Text is written verbatim, without any escaping.
void ParameterTree::ToStringInternal(int32_t depth, std::ostream& ss) const {
    // A non-positive depth yields no indentation.
    const std::string indent(depth > 0 ? static_cast<std::size_t>(2 * depth) : 0, ' ');
    ss << indent << "<" << m_name << ">";
    if (m_children.empty()) {
        ss << m_text << "</" << m_name << ">\n";
        return;
    }
    ss << "\n";
    for (const auto& child : m_children) {
        child->ToStringInternal(depth + 1, ss);
    }
    ss << indent << "</" << m_name << ">\n";
}
// Deep copy of the name, text and (recursively) children. Registered-parameter
// bindings are not copied, so the clone starts with none.
std::shared_ptr<ParameterTree> ParameterTree::Clone() const {
    auto copy = std::make_shared<ParameterTree>(m_name);
    copy->m_text = m_text;
    for (std::size_t idx = 0; idx < m_children.size(); ++idx) {
        copy->m_children.push_back(m_children[idx]->Clone());
    }
    return copy;
}
// Overlays `other` onto this tree, with `other` taking precedence.
//  - This node's name and text are replaced by other's, even if other's text is empty.
//  - For each child of `other`, the first child here with the same name is updated:
//      * if both texts are non-empty, only the text is overwritten;
//      * otherwise the two children are merged recursively. That also overwrites the
//        child's text, possibly with "".
//  - Children of `other` with no counterpart here are appended as deep copies (Clone),
//    so later changes to `other` do not show up in this tree.
void ParameterTree::Merge(const ParameterTree& other) {
    m_name = other.m_name;
    m_text = other.m_text;
    for (const auto& incoming : other.m_children) {
        if (!HasChild(incoming->Name())) {
            m_children.push_back(incoming->Clone());
            continue;
        }
        auto existing = GetChildReq(incoming->Name());
        const bool bothHaveText = incoming->Text() != "" && existing->Text() != "";
        if (bothHaveText) {
            existing->SetText(incoming->Text());
        } else {
            existing->Merge(*incoming);
        }
    }
}
// Expands each "$$name$$" token in this node's text using `vars`, then does the same
// for every descendant.
//  - A known name is replaced by its value.
//  - An unknown name raises an error when error_on_unknown_vars is true; otherwise the
//    token is kept verbatim.
//  - An opening "$$" with no closing "$$" ends scanning; the rest is copied unchanged.
// This node's text is only replaced once the whole string has been processed, so an
// error leaves it untouched.
void ParameterTree::ReplaceVariablesInternal(
    const std::unordered_map<std::string, std::string>& vars,
    bool error_on_unknown_vars)
{
    std::string expanded;
    std::size_t cursor = 0;
    for (;;) {
        const std::size_t open = m_text.find("$$", cursor);
        if (open == std::string::npos) {
            break;
        }
        const std::size_t nameBegin = open + 2;
        const std::size_t close = m_text.find("$$", nameBegin);
        if (close == std::string::npos) {
            break;
        }
        // Literal text between the previous token and this one.
        expanded.append(m_text, cursor, open - cursor);
        const std::string varName = m_text.substr(nameBegin, close - nameBegin);
        const auto hit = vars.find(varName);
        if (hit != vars.end()) {
            expanded += hit->second;
        } else if (error_on_unknown_vars) {
            LOG_ERROR_AND_THROW("The variable $$%s$$ was not found", varName.c_str());
        } else {
            expanded += "$$";
            expanded += varName;
            expanded += "$$";
        }
        cursor = close + 2;
    }
    expanded.append(m_text, cursor, std::string::npos);
    m_text.swap(expanded);
    for (const auto& child : m_children) {
        child->ReplaceVariablesInternal(vars, error_on_unknown_vars);
    }
}
} // namespace quicksand
} // namespace marian