blob: ba64075a1043c47d1bd28d390300a17f996cc22a [file] [log] [blame]
// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifdef UNSAFE_BUFFERS_BUILD
// TODO(crbug.com/40285824): Remove this and convert code to safer constructs.
#pragma allow_unsafe_buffers
#endif
#include "components/segmentation_platform/internal/metadata/metadata_writer.h"
#include <cstddef>
#include <optional>
#include <vector>
#include "base/check.h"
#include "base/metrics/metrics_hashes.h"
#include "base/strings/strcat.h"
#include "components/segmentation_platform/public/constants.h"
#include "components/segmentation_platform/public/proto/model_metadata.pb.h"
namespace segmentation_platform {
namespace {
void FillCustomInput(const MetadataWriter::CustomInput feature,
proto::CustomInput& input) {
input.set_tensor_length(feature.tensor_length);
input.set_fill_policy(feature.fill_policy);
for (size_t i = 0; i < feature.default_values_size; ++i) {
input.add_default_value(feature.default_values[i]);
}
if (feature.name) {
input.set_name(feature.name);
}
for (size_t i = 0; i < feature.arg_size; ++i) {
(*input.mutable_additional_args())[feature.arg[i].first] =
std::string(feature.arg[i].second);
}
}
template <typename StringVector>
void PopulateMultiClassClassifier(
proto::Predictor::MultiClassClassifier* multi_class_classifier,
const StringVector& class_labels,
int top_k_outputs) {
multi_class_classifier->set_top_k_outputs(top_k_outputs);
for (const auto& class_label : class_labels) {
multi_class_classifier->mutable_class_labels()->Add(
std::string(class_label));
}
}
} // namespace
MetadataWriter::MetadataWriter(proto::SegmentationModelMetadata* metadata)
: metadata_(metadata) {}
MetadataWriter::~MetadataWriter() = default;
void MetadataWriter::AddUmaFeatures(const UMAFeature features[],
size_t features_size,
bool is_output) {
for (size_t i = 0; i < features_size; i++) {
const auto& feature = features[i];
proto::UMAFeature* uma_feature;
if (is_output) {
auto* training_output =
metadata_->mutable_training_outputs()->add_outputs();
uma_feature =
training_output->mutable_uma_output()->mutable_uma_feature();
} else {
auto* input_feature = metadata_->add_input_features();
uma_feature = input_feature->mutable_uma_feature();
}
uma_feature->set_type(feature.signal_type);
uma_feature->set_name(feature.name);
uma_feature->set_name_hash(base::HashMetricName(feature.name));
uma_feature->set_bucket_count(feature.bucket_count);
uma_feature->set_tensor_length(feature.tensor_length);
uma_feature->set_aggregation(feature.aggregation);
for (size_t j = 0; j < feature.enum_ids_size; j++) {
uma_feature->add_enum_ids(feature.accepted_enum_ids[j]);
}
for (size_t j = 0; j < feature.default_values_size; j++) {
uma_feature->add_default_values(feature.default_values[j]);
}
}
}
proto::SqlFeature* MetadataWriter::AddSqlFeature(const SqlFeature& feature) {
proto::SqlFeature* proto =
metadata_->add_input_features()->mutable_sql_feature();
proto->set_sql(feature.sql);
for (size_t ev = 0; ev < feature.events_size; ++ev) {
const auto& event = feature.events[ev];
auto* ukm_event = proto->mutable_signal_filter()->add_ukm_events();
ukm_event->set_event_hash(event.event_hash.GetUnsafeValue());
for (size_t m = 0; m < event.metrics_size; ++m) {
ukm_event->mutable_metric_hash_filter()->Add(
event.metrics[m].GetUnsafeValue());
}
}
return proto;
}
proto::SqlFeature* MetadataWriter::AddSqlFeature(
const SqlFeature& feature,
const BindValues& bind_values) {
auto* proto = AddSqlFeature(feature);
unsigned index = 0;
for (const auto& it : bind_values) {
auto* value = proto->add_bind_values();
for (unsigned i = index; i < index + it.second.tensor_length; ++i) {
value->add_bind_field_index(i);
}
index += it.second.tensor_length;
value->set_param_type(it.first);
FillCustomInput(it.second, *value->mutable_value());
}
return proto;
}
proto::CustomInput* MetadataWriter::AddCustomInput(const CustomInput& feature) {
proto::CustomInput* proto =
metadata_->add_input_features()->mutable_custom_input();
FillCustomInput(feature, *proto);
return proto;
}
void MetadataWriter::AddDiscreteMappingEntries(
const std::string& key,
const std::pair<float, int>* mappings,
size_t mappings_size) {
auto* discrete_mappings = metadata_->mutable_discrete_mappings();
for (size_t i = 0; i < mappings_size; i++) {
auto* discrete_mapping_entry = (*discrete_mappings)[key].add_entries();
discrete_mapping_entry->set_min_result(mappings[i].first);
discrete_mapping_entry->set_rank(mappings[i].second);
}
}
void MetadataWriter::AddBooleanSegmentDiscreteMapping(const std::string& key) {
const int selected_rank = 1;
const float model_score = 1;
const std::pair<float, int> mappings[]{{model_score, selected_rank}};
AddDiscreteMappingEntries(key, mappings, 1);
}
void MetadataWriter::AddBooleanSegmentDiscreteMappingWithSubsegments(
const std::string& key,
float threshold,
int max_value) {
DCHECK_GT(threshold, 0);
// Should record at least 2 subsegments.
DCHECK_GT(max_value, 1);
const int selected_rank = 1;
const std::pair<float, int> mappings[]{{threshold, selected_rank}};
AddDiscreteMappingEntries(key, mappings, 1);
std::vector<std::pair<float, int>> subsegment_mapping;
for (int i = 1; i <= max_value; ++i) {
subsegment_mapping.emplace_back(i, i);
}
AddDiscreteMappingEntries(
base::StrCat({key, kSubsegmentDiscreteMappingSuffix}),
subsegment_mapping.data(), subsegment_mapping.size());
}
void MetadataWriter::SetSegmentationMetadataConfig(
proto::TimeUnit time_unit,
uint64_t bucket_duration,
int64_t signal_storage_length,
int64_t min_signal_collection_length,
int64_t result_time_to_live) {
metadata_->set_time_unit(time_unit);
metadata_->set_bucket_duration(bucket_duration);
metadata_->set_signal_storage_length(signal_storage_length);
metadata_->set_min_signal_collection_length(min_signal_collection_length);
metadata_->set_result_time_to_live(result_time_to_live);
}
void MetadataWriter::SetDefaultSegmentationMetadataConfig(
int min_signal_collection_length_days,
int signal_storage_length_days) {
SetSegmentationMetadataConfig(proto::TimeUnit::DAY, /*bucket_duration=*/1,
signal_storage_length_days,
min_signal_collection_length_days,
/*result_time_to_live=*/1);
}
void MetadataWriter::AddOutputConfigForBinaryClassifier(
float threshold,
const std::string& positive_label,
const std::string& negative_label) {
proto::Predictor::BinaryClassifier* binary_classifier =
metadata_->mutable_output_config()
->mutable_predictor()
->mutable_binary_classifier();
binary_classifier->set_threshold(threshold);
binary_classifier->set_positive_label(positive_label);
binary_classifier->set_negative_label(negative_label);
}
void MetadataWriter::SetIgnorePreviousModelTTLInOutputConfig() {
metadata_->mutable_output_config()->set_ignore_previous_model_ttl(true);
}
void MetadataWriter::AddOutputConfigForMultiClassClassifier(
base::span<const char* const> class_labels,
int top_k_outputs,
std::optional<float> threshold) {
proto::Predictor::MultiClassClassifier* multi_class_classifier =
metadata_->mutable_output_config()
->mutable_predictor()
->mutable_multi_class_classifier();
PopulateMultiClassClassifier(multi_class_classifier, class_labels,
top_k_outputs);
if (threshold.has_value()) {
multi_class_classifier->set_threshold(threshold.value());
}
}
void MetadataWriter::AddOutputConfigForMultiClassClassifier(
const std::vector<std::string>& class_labels,
int top_k_outputs,
std::optional<float> threshold) {
proto::Predictor::MultiClassClassifier* multi_class_classifier =
metadata_->mutable_output_config()
->mutable_predictor()
->mutable_multi_class_classifier();
PopulateMultiClassClassifier(multi_class_classifier, class_labels,
top_k_outputs);
if (threshold.has_value()) {
multi_class_classifier->set_threshold(threshold.value());
}
}
void MetadataWriter::AddOutputConfigForMultiClassClassifier(
base::span<const char* const> class_labels,
int top_k_outputs,
const base::span<float> per_class_thresholds) {
CHECK_EQ(class_labels.size(), per_class_thresholds.size());
proto::Predictor::MultiClassClassifier* multi_class_classifier =
metadata_->mutable_output_config()
->mutable_predictor()
->mutable_multi_class_classifier();
PopulateMultiClassClassifier(multi_class_classifier, class_labels,
top_k_outputs);
for (float per_class_threshold : per_class_thresholds) {
multi_class_classifier->add_class_thresholds(per_class_threshold);
}
}
void MetadataWriter::AddOutputConfigForBinnedClassifier(
const std::vector<std::pair<float, std::string>>& bins,
std::string underflow_label) {
proto::Predictor::BinnedClassifier* binned_classifier =
metadata_->mutable_output_config()
->mutable_predictor()
->mutable_binned_classifier();
binned_classifier->set_underflow_label(underflow_label);
for (const std::pair<float, std::string>& bin : bins) {
proto::Predictor::BinnedClassifier::Bin* current_bin =
binned_classifier->add_bins();
current_bin->set_min_range(bin.first);
current_bin->set_label(bin.second);
}
}
void MetadataWriter::AddOutputConfigForGenericPredictor(
const std::vector<std::string>& labels) {
proto::Predictor::GenericPredictor* generic_predictor =
metadata_->mutable_output_config()
->mutable_predictor()
->mutable_generic_predictor();
generic_predictor->mutable_output_labels()->Assign(labels.begin(),
labels.end());
}
void MetadataWriter::AddPredictedResultTTLInOutputConfig(
std::vector<std::pair<std::string, std::int64_t>> top_label_to_ttl_list,
int64_t default_ttl,
proto::TimeUnit time_unit) {
proto::PredictedResultTTL* predicted_result_ttl =
metadata_->mutable_output_config()->mutable_predicted_result_ttl();
predicted_result_ttl->set_time_unit(time_unit);
predicted_result_ttl->set_default_ttl(default_ttl);
auto* top_label_to_ttl_map =
predicted_result_ttl->mutable_top_label_to_ttl_map();
for (const std::pair<std::string, int64_t>& label_to_ttl :
top_label_to_ttl_list) {
(*top_label_to_ttl_map)[label_to_ttl.first] = label_to_ttl.second;
}
}
void MetadataWriter::AddDelayTrigger(uint64_t delay_sec) {
auto* config =
metadata_->mutable_training_outputs()->mutable_trigger_config();
auto* trigger = config->add_observation_trigger();
trigger->set_delay_sec(delay_sec);
config->set_decision_type(proto::TrainingOutputs::TriggerConfig::ONDEMAND);
}
void MetadataWriter::AddFromInputContext(const char* custom_input_name,
const char* additional_args_name) {
proto::CustomInput* custom_input = AddCustomInput(MetadataWriter::CustomInput{
.tensor_length = 1,
.fill_policy = proto::CustomInput::FILL_FROM_INPUT_CONTEXT,
.name = custom_input_name});
(*custom_input->mutable_additional_args())["name"] = additional_args_name;
}
} // namespace segmentation_platform