.. _program_listing_file_src_microsoft_shortlist_utils_ParameterTree.cpp:

Program Listing for File ParameterTree.cpp
==========================================

|exhale_lsh| :ref:`Return to documentation for file <file_src_microsoft_shortlist_utils_ParameterTree.cpp>` (``src/microsoft/shortlist/utils/ParameterTree.cpp``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #include "microsoft/shortlist/utils/ParameterTree.h"
   
   #include <sstream>
   #include <string>
   #include <utility>
   
   #include "microsoft/shortlist/utils/StringUtils.h"
   #include "microsoft/shortlist/utils/Converter.h"
   
   namespace marian {
   namespace quicksand {
   
   #include "microsoft/shortlist/logging/LoggerMacros.h"
   
   std::shared_ptr<ParameterTree> ParameterTree::m_empty_tree = std::make_shared<ParameterTree>("params");
   
   ParameterTree::ParameterTree() {
       m_name = "root";
   }
   
   ParameterTree::ParameterTree(const std::string& name) {
       m_name = name;
   }
   
   ParameterTree::~ParameterTree() {
   }
   
   // NOTE(review): Clear() is currently a no-op; it does not remove children,
   // text, or registered parameters.
   void ParameterTree::Clear() {
   }
   
   /// Substitutes $$name$$ placeholders in this node's text and, recursively,
   /// in the text of every descendant.
   /// @param vars                  placeholder name -> replacement text
   /// @param error_on_unknown_vars when true, a placeholder missing from `vars`
   ///                              throws; otherwise it is left verbatim
   void ParameterTree::ReplaceVariables(
       const std::unordered_map<std::string, std::string>& vars,
       bool error_on_unknown_vars) {
       ReplaceVariablesInternal(vars, error_on_unknown_vars);
   }
   
   // The Register* family records a (name, type, destination) triple. The
   // destination pointer is stored and only written later by
   // SetRegisteredParams(), so it must stay valid until that call.
   // Registering the same name twice throws (see RegisterItemInternal).
   void ParameterTree::RegisterInt32(const std::string& name, int32_t * param) {
       RegisterItemInternal(name, PARAM_TYPE_INT32, static_cast<void *>(param));
   }
   
   void ParameterTree::RegisterInt64(const std::string& name, int64_t * param) {
       RegisterItemInternal(name, PARAM_TYPE_INT64, static_cast<void *>(param));
   }
   
   void ParameterTree::RegisterFloat(const std::string& name, float * param) {
       RegisterItemInternal(name, PARAM_TYPE_FLOAT, static_cast<void *>(param));
   }
   
   void ParameterTree::RegisterDouble(const std::string& name, double * param) {
       RegisterItemInternal(name, PARAM_TYPE_DOUBLE, static_cast<void *>(param));
   }
   
   void ParameterTree::RegisterBool(const std::string& name, bool * param) {
       RegisterItemInternal(name, PARAM_TYPE_BOOL, static_cast<void *>(param));
   }
   
   void ParameterTree::RegisterString(const std::string& name, std::string * param) {
       RegisterItemInternal(name, PARAM_TYPE_STRING, static_cast<void *>(param));
   }
   
   /// Builds a new tree from the binary layout read by ReadBinary().
   /// `current` is advanced past all bytes consumed.
   std::shared_ptr<ParameterTree> ParameterTree::FromBinaryReader(const void*& current) {
       std::shared_ptr<ParameterTree> root = std::make_shared<ParameterTree>();
       root->ReadBinary(current);
       return root;
   }
   
   /// Writes each registered parameter's value from this tree into the
   /// variable it was bound to. Every registered name is required: a missing
   /// entry throws through the corresponding Get*Req call.
   void ParameterTree::SetRegisteredParams() {
       for (const RegisteredParam& rp : m_registered_params) {
           // Every type accepted by a Register* method needs a case here;
           // anything else falls through to the throwing default branch.
           switch (rp.Type()) {
           case PARAM_TYPE_INT32:
               (*(int32_t *)rp.Data()) = GetInt32Req(rp.Name());
               break;
           case PARAM_TYPE_INT64:
               (*(int64_t *)rp.Data()) = GetInt64Req(rp.Name());
               break;
           case PARAM_TYPE_FLOAT:
               (*(float *)rp.Data()) = GetFloatReq(rp.Name());
               break;
           case PARAM_TYPE_DOUBLE:
               (*(double *)rp.Data()) = GetDoubleReq(rp.Name());
               break;
           case PARAM_TYPE_BOOL:
               (*(bool *)rp.Data()) = GetBoolReq(rp.Name());
               break;
           case PARAM_TYPE_STRING:
               (*(std::string *)rp.Data()) = GetStringReq(rp.Name());
               break;
           default:
               LOG_ERROR_AND_THROW("Unknown ParameterType: %d", (int)rp.Type());
           }
       }
   }
   
   // Get*Or: look up the first direct child with the given name and convert
   // its text with Converter; return `defaultValue` when no such child exists.
   int32_t ParameterTree::GetInt32Or(const std::string& name, int32_t defaultValue) const {
       const std::string * value = GetParamInternal(name);
       if (value == nullptr) {
           return defaultValue;
       }
       return Converter::ToInt32(*value);
   }
   
   int64_t ParameterTree::GetInt64Or(const std::string& name, int64_t defaultValue) const {
       const std::string * value = GetParamInternal(name);
       if (value == nullptr) {
           return defaultValue;
       }
       return Converter::ToInt64(*value);
   }
   
   uint64_t ParameterTree::GetUInt64Or(const std::string& name, uint64_t defaultValue) const {
       const std::string * value = GetParamInternal(name);
       if (value == nullptr) {
           return defaultValue;
       }
       return Converter::ToUInt64(*value);
   }
   
   double ParameterTree::GetDoubleOr(const std::string& name, double defaultValue) const {
       const std::string * value = GetParamInternal(name);
       if (value == nullptr) {
           return defaultValue;
       }
       return Converter::ToDouble(*value);
   }
   
   float ParameterTree::GetFloatOr(const std::string& name, float defaultValue) const {
       const std::string * value = GetParamInternal(name);
       if (value == nullptr) {
           return defaultValue;
       }
       return Converter::ToFloat(*value);
   }
   
   std::string ParameterTree::GetStringOr(const std::string& name, const std::string& defaultValue) const {
       const std::string * value = GetParamInternal(name);
       if (value == nullptr) {
           return defaultValue;
       }
       return *value;
   }
   
   bool ParameterTree::GetBoolOr(const std::string& name, bool defaultValue) const {
       const std::string * value = GetParamInternal(name);
       if (value == nullptr) {
           return defaultValue;
       }
       return Converter::ToBool(*value);
   }
   
   // Get*Req: like Get*Or, but a missing parameter throws (via GetStringReq).
   int32_t ParameterTree::GetInt32Req(const std::string& name) const {
       return Converter::ToInt32(GetStringReq(name));
   }
   
   uint64_t ParameterTree::GetUInt64Req(const std::string& name) const {
       return Converter::ToUInt64(GetStringReq(name));
   }
   
   int64_t ParameterTree::GetInt64Req(const std::string& name) const {
       return Converter::ToInt64(GetStringReq(name));
   }
   
   double ParameterTree::GetDoubleReq(const std::string& name) const {
       return Converter::ToDouble(GetStringReq(name));
   }
   
   float ParameterTree::GetFloatReq(const std::string& name) const {
       return Converter::ToFloat(GetStringReq(name));
   }
   
   bool ParameterTree::GetBoolReq(const std::string& name) const {
       return Converter::ToBool(GetStringReq(name));
   }
   
   /// Returns the raw text of the named parameter; throws, including a dump
   /// of this tree in the message, when it is absent.
   std::string ParameterTree::GetStringReq(const std::string& name) const {
       const std::string * value = GetParamInternal(name);
       if (value == nullptr) {
           LOG_ERROR_AND_THROW("Required parameter <%s> not found in ParameterTree:\n%s", name.c_str(), ToString().c_str());
       }
       return *value;
   }
   
   /// Like GetFileListOptional(), but throws when the resulting list is empty.
   std::vector<std::string> ParameterTree::GetFileListReq(const std::string& name) const {
       std::vector<std::string> output = GetFileListOptional(name);
       if (output.empty()) {
           LOG_ERROR_AND_THROW("No files were found for parameter: %s", name.c_str());
       }
       return output;
   }
   
   /// Splits the named parameter on ';' into a list of file names. Returns an
   /// empty list when the parameter is absent or its text is empty.
   std::vector<std::string> ParameterTree::GetFileListOptional(const std::string& name) const {
       const std::string * value = GetParamInternal(name);
       if (value == nullptr || value->empty()) {
           return std::vector<std::string>();
       }
       std::vector<std::string> all_files = StringUtils::Split(*value, ";");
       return all_files;
   }
   
   /// Splits the (required) named parameter on `sep`.
   std::vector<std::string> ParameterTree::GetStringListReq(const std::string& name, const std::string& sep) const {
       std::vector<std::string> output = StringUtils::Split(GetStringReq(name), sep);
       return output;
   }
   
   /// Splits the named parameter on `sep`, treating an absent parameter as "".
   std::vector<std::string> ParameterTree::GetStringListOptional(const std::string& name, const std::string& sep) const {
       std::vector<std::string> output = StringUtils::Split(GetStringOr(name, ""), sep);
       return output;
   }
   
   /// Returns the first direct child with the given name; throws if none exists.
   std::shared_ptr<ParameterTree> ParameterTree::GetChildReq(const std::string& name) const {
       for (const auto& child : m_children) {
           if (child->Name() == name) {
               return child;
           }
       }
       LOG_ERROR_AND_THROW("Unable to find child ParameterTree with name '%s'", name.c_str());
       return nullptr; // never happens
   }
   
   /// Returns the first direct child with the given name, or a freshly
   /// allocated default-constructed tree when none exists.
   std::shared_ptr<ParameterTree> ParameterTree::GetChildOrEmpty(const std::string& name) const {
       for (const auto& child : m_children) {
           if (child->Name() == name) {
               return child;
           }
       }
       return std::make_shared<ParameterTree>();
   }
   
   namespace {
   
   // Reinterprets `current` as a pointer to T, advances `current` past `num`
   // elements of T, and returns the pointer to the first of them.
   // File-local so it cannot collide with another `get` in marian::quicksand.
   template <typename T>
   const T* get(const void*& current, size_t num = 1) {
       const T* ptr = static_cast<const T*>(current);
       current = ptr + num;
       return ptr;
   }
   
   } // namespace
   
   /// Reads this node from a flat binary buffer, recursing into children:
   ///   name length, name bytes, text length, text bytes, child count, children...
   /// NOTE(review): lengths are read as int32_t and the buffer is trusted
   /// (no bounds checking) -- confirm against the writer of this format.
   void ParameterTree::ReadBinary(const void*& current) {
       auto nameLength = *get<int32_t>(current);
       auto nameBytes = get<char>(current, nameLength);
       m_name = std::string(nameBytes, nameBytes + nameLength);
   
       auto textLength = *get<int32_t>(current);
       auto textBytes = get<char>(current, textLength);
       m_text = std::string(textBytes, textBytes + textLength);
   
       int32_t num_children = *get<int32_t>(current);
       m_children.resize(num_children);
       for (int32_t i = 0; i < num_children; i++) {
           m_children[i] = std::make_shared<ParameterTree>();
           m_children[i]->ReadBinary(current);
       }
   }
   
   /// Returns every direct child whose name equals `name`, in insertion order.
   std::vector< std::shared_ptr<ParameterTree> > ParameterTree::GetChildren(const std::string& name) const {
       std::vector< std::shared_ptr<ParameterTree> > children;
       for (const auto& child : m_children) {
           if (child->Name() == name) {
               children.push_back(child);
           }
       }
       return children;
   }
   
   /// Appends a new leaf child (duplicates of `name` are allowed).
   void ParameterTree::AddParam(const std::string& name, const std::string& text) {
       std::shared_ptr<ParameterTree> child = std::make_shared<ParameterTree>(name);
       child->SetText(text);
       m_children.push_back(child);
   }
   
   /// Overwrites the text of the first child named `name`, or appends a new
   /// leaf child when there is none.
   void ParameterTree::SetParam(const std::string& name, const std::string& text) {
       for (const auto& child : m_children) {
           if (child->Name() == name) {
               child->SetText(text);
               return;
           }
       }
       std::shared_ptr<ParameterTree> child = std::make_shared<ParameterTree>(name);
       child->SetText(text);
       m_children.push_back(child);
   }
   
   void ParameterTree::AddChild(std::shared_ptr<ParameterTree> child) {
       m_children.push_back(std::move(child));
   }
   
   bool ParameterTree::HasParam(const std::string& name) const {
       return GetParamInternal(name) != nullptr;
   }
   
   bool ParameterTree::HasChild(const std::string& name) const {
       for (const auto& child : m_children) {
           if (child->Name() == name) {
               return true;
           }
       }
       return false;
   }
   
   /// Renders the tree as indented XML-like text (see ToStringInternal).
   std::string ParameterTree::ToString() const {
       std::ostringstream ss;
       ToStringInternal(0, ss);
       return ss.str();
   }
   
   /// Returns a pointer to the text of the first direct child named `name`,
   /// or nullptr when there is none. The pointer refers into that child.
   const std::string * ParameterTree::GetParamInternal(const std::string& name) const {
       for (const auto& child : m_children) {
           if (child->Name() == name) {
               return &(child->Text());
           }
       }
       return nullptr;
   }
   
   void ParameterTree::RegisterItemInternal(const std::string& name, ParameterType type, void * param) {
       if (m_registered_param_names.find(name) != m_registered_param_names.end()) {
           LOG_ERROR_AND_THROW("Unable to register duplicate parameter name: '%s'", name.c_str());
       }
       m_registered_params.push_back(RegisteredParam(name, type, param));
       m_registered_param_names.insert(name);
   }
   
   /// Writes this node as <name>text</name> (leaf) or as <name>, the children
   /// one level deeper, then </name> (inner node), indented by 2*depth spaces.
   void ParameterTree::ToStringInternal(int32_t depth, std::ostream& ss) const {
       for (int32_t i = 0; i < 2 * depth; i++) {
           ss << " ";
       }
       ss << "<" << m_name << ">";
       if (!m_children.empty()) {
           ss << "\n";
           for (const std::shared_ptr<ParameterTree>& child : m_children) {
               child->ToStringInternal(depth + 1, ss);
           }
           for (int32_t i = 0; i < 2 * depth; i++) {
               ss << " ";
           }
           ss << "</" << m_name << ">\n";
       }
       else {
           ss << m_text << "</" << m_name << ">\n";
       }
   }
   
   /// Deep copy of name, text, and all descendants. Registered parameters are
   /// not copied.
   std::shared_ptr<ParameterTree> ParameterTree::Clone() const {
       std::shared_ptr<ParameterTree> node = std::make_shared<ParameterTree>(m_name);
       node->m_text = m_text;
       for (const auto& child : m_children) {
           node->m_children.push_back(child->Clone());
       }
       return node;
   }
   
   /// Overlays `other` onto this tree. This node takes other's name and text.
   /// For each child of `other`: if a same-named child exists here and both
   /// have non-empty text, its text is overwritten; if only one of them has
   /// text, the two are merged recursively; otherwise a deep copy of
   /// other's child is appended.
   void ParameterTree::Merge(const ParameterTree& other) {
       m_name = other.m_name;
       m_text = other.m_text;
       for (const auto& other_child : other.m_children) {
           if (HasChild(other_child->Name())) {
               auto my_child = GetChildReq(other_child->Name());
               if (other_child->Text() != "" && my_child->Text() != "") {
                   my_child->SetText(other_child->Text());
               }
               else {
                   my_child->Merge(*other_child);
               }
           }
           else {
               m_children.push_back(other_child->Clone());
           }
       }
   }
   
   /// Rewrites m_text, replacing each $$name$$ with vars[name], then recurses
   /// into all children. An opening "$$" without a matching closing "$$" is
   /// left untouched, as is the remainder of the text after it.
   void ParameterTree::ReplaceVariablesInternal(
       const std::unordered_map<std::string, std::string>& vars,
       bool error_on_unknown_vars) {
       std::size_t offset = 0;
       std::ostringstream ss;
       while (true) {
           std::size_t s_pos = m_text.find("$$", offset);
           if (s_pos == std::string::npos) {
               break;
           }
           std::size_t e_pos = m_text.find("$$", s_pos + 2);
           if (e_pos == std::string::npos) {
               break;
           }
   
           // Copy the literal text preceding the placeholder.
           if (offset != s_pos) {
               ss << m_text.substr(offset, s_pos - offset);
           }
   
           std::string var_name = m_text.substr(s_pos + 2, e_pos - (s_pos + 2));
           auto it = vars.find(var_name);
           if (it != vars.end()) {
               ss << it->second;
           }
           else {
               if (error_on_unknown_vars) {
                   LOG_ERROR_AND_THROW("The variable $$%s$$ was not found", var_name.c_str());
               }
               else {
                   ss << "$$" << var_name << "$$";
               }
           }
           offset = e_pos + 2;
       }
       ss << m_text.substr(offset);
   
       m_text = ss.str();
   
       for (const auto& child : m_children) {
           child->ReplaceVariablesInternal(vars, error_on_unknown_vars);
       }
   }
   
   } // namespace quicksand
   } // namespace marian