Segmentation platform: Support recording training data to UKM.
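Adds SegmentationUkmHelper::RecordTrainingData(), which records a model's input tensor together with its actual results to UKM as training data. Input recording is factored into AddInputsToUkm() so RecordModelExecutionResult() and RecordTrainingData() share it.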
Bug: 1295447
Change-Id: I5f62b34ee17ecc3ff1d11c69563a6e95cebbe612
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3482085
Reviewed-by: Min Qin <[email protected]>
Reviewed-by: Siddhartha S <[email protected]>
Commit-Queue: Xing Liu <[email protected]>
Cr-Commit-Position: refs/heads/main@{#974412}
diff --git a/components/segmentation_platform/internal/segmentation_ukm_helper.cc b/components/segmentation_platform/internal/segmentation_ukm_helper.cc
index c1c434493..8a99d1b 100644
--- a/components/segmentation_platform/internal/segmentation_ukm_helper.cc
+++ b/components/segmentation_platform/internal/segmentation_ukm_helper.cc
@@ -23,7 +23,7 @@
using UkmMemberFn =
Segmentation_ModelExecution& (Segmentation_ModelExecution::*)(int64_t);
-const UkmMemberFn kSegmentationUkmMethods[] = {
+const UkmMemberFn kSegmentationUkmInputMethods[] = {
&Segmentation_ModelExecution::SetInput0,
&Segmentation_ModelExecution::SetInput1,
&Segmentation_ModelExecution::SetInput2,
@@ -55,6 +55,14 @@
&Segmentation_ModelExecution::SetInput28,
&Segmentation_ModelExecution::SetInput29};
+// UKM setters for the model's actual results. AddOutputsToUkm() picks the
+// setter for each output by index into this table, so entry order is part of
+// the contract: index 0 -> ActualResult, index 1 -> ActualResult2, and so on.
+const UkmMemberFn kSegmentationUkmOutputMethods[] = {
+ &Segmentation_ModelExecution::SetActualResult,
+ &Segmentation_ModelExecution::SetActualResult2,
+ &Segmentation_ModelExecution::SetActualResult3,
+ &Segmentation_ModelExecution::SetActualResult4,
+ &Segmentation_ModelExecution::SetActualResult5,
+ &Segmentation_ModelExecution::SetActualResult6};
+
// Gets a set of segment IDs that are allowed to upload metrics.
base::flat_set<int> GetSegmentIdsAllowedForReporting() {
std::vector<std::string> segment_ids = base::SplitString(
@@ -98,30 +106,86 @@
int64_t model_version,
const std::vector<float>& input_tensor,
float result) {
- // Check if the |segment_id| is allowed to record metrics.
- if (!allowed_segment_ids_.contains(static_cast<int>(segment_id)))
- return ukm::kInvalidSourceId;
-
ukm::SourceId source_id = ukm::NoURLSourceId();
- // Don't record UKM if there are too many tensors.
- if (input_tensor.size() > ARRAY_SIZE(kSegmentationUkmMethods)) {
- stats::RecordTooManyInputTensors(input_tensor.size());
- return ukm::kInvalidSourceId;
- }
ukm::builders::Segmentation_ModelExecution execution_result(source_id);
- execution_result.SetOptimizationTarget(segment_id)
- .SetModelVersion(model_version);
+ // Add inputs to ukm message.
+ if (!AddInputsToUkm(&execution_result, segment_id, model_version,
+ input_tensor))
+ return ukm::kInvalidSourceId;
- for (size_t i = 0; i < input_tensor.size(); ++i) {
- CALL_MEMBER_FN(execution_result, kSegmentationUkmMethods[i])
- (FloatToInt64(input_tensor[i]));
- }
+ // TODO(xingliu): Also record continuous outputs for model execution.
execution_result.SetPredictionResult(FloatToInt64(result))
.Record(ukm::UkmRecorder::Get());
return source_id;
}
+// Records one Segmentation.ModelExecution UKM event that carries both the
+// model inputs and the observed outputs. Returns the no-URL source id the
+// event was recorded under, or ukm::kInvalidSourceId if either the inputs or
+// the outputs were rejected (in which case nothing is recorded).
+ukm::SourceId SegmentationUkmHelper::RecordTrainingData(
+    OptimizationTarget segment_id,
+    int64_t model_version,
+    const std::vector<float>& input_tensor,
+    const std::vector<float>& outputs,
+    const std::vector<int>& output_indexes) {
+  const ukm::SourceId no_url_source = ukm::NoURLSourceId();
+  ukm::builders::Segmentation_ModelExecution training_event(no_url_source);
+
+  // Short-circuits: outputs are only added once the inputs were accepted.
+  const bool event_populated =
+      AddInputsToUkm(&training_event, segment_id, model_version,
+                     input_tensor) &&
+      AddOutputsToUkm(&training_event, outputs, output_indexes);
+  if (!event_populated)
+    return ukm::kInvalidSourceId;
+
+  training_event.Record(ukm::UkmRecorder::Get());
+  return no_url_source;
+}
+
+// Populates |ukm_builder| with the optimization target, model version and one
+// InputN field per element of |input_tensor|. Returns false without touching
+// the builder when |segment_id| is not allowed to report, or when the tensor
+// has more elements than there are input setters (that case is also counted
+// via stats::RecordTooManyInputTensors).
+bool SegmentationUkmHelper::AddInputsToUkm(
+    ukm::builders::Segmentation_ModelExecution* ukm_builder,
+    OptimizationTarget segment_id,
+    int64_t model_version,
+    const std::vector<float>& input_tensor) {
+  const bool reporting_allowed =
+      allowed_segment_ids_.contains(static_cast<int>(segment_id));
+  if (!reporting_allowed)
+    return false;
+
+  // The event has a fixed number of input slots; refuse rather than truncate.
+  const size_t tensor_length = input_tensor.size();
+  if (tensor_length > ARRAY_SIZE(kSegmentationUkmInputMethods)) {
+    stats::RecordTooManyInputTensors(tensor_length);
+    return false;
+  }
+
+  ukm_builder->SetOptimizationTarget(segment_id).SetModelVersion(model_version);
+  for (size_t slot = 0; slot < tensor_length; ++slot) {
+    const float feature = input_tensor[slot];
+    CALL_MEMBER_FN(*ukm_builder, kSegmentationUkmInputMethods[slot])
+    (FloatToInt64(feature));
+  }
+  return true;
+}
+
+// Writes |outputs| into the ActualResult* fields of |ukm_builder|: outputs[i]
+// is stored via the setter at kSegmentationUkmOutputMethods[output_indexes[i]].
+// Returns false if the two vectors differ in length, if there are more outputs
+// than setters, or if any index lies outside [0, number of setters). Every
+// index is validated before any setter runs, so a rejected call never leaves
+// the builder partially populated with outputs.
+bool SegmentationUkmHelper::AddOutputsToUkm(
+    ukm::builders::Segmentation_ModelExecution* ukm_builder,
+    const std::vector<float>& outputs,
+    const std::vector<int>& output_indexes) {
+  DCHECK(!outputs.empty());
+  if (outputs.size() != output_indexes.size())
+    return false;
+
+  // Keep every size comparison unsigned; mixing size_t and int here triggers
+  // -Wsign-compare.
+  const size_t output_methods_size = ARRAY_SIZE(kSegmentationUkmOutputMethods);
+  if (outputs.size() > output_methods_size)
+    return false;
+
+  // A negative index would read before the start of the setter table (UB), so
+  // it must be rejected together with indexes past the end.
+  for (int index : output_indexes) {
+    if (index < 0 || static_cast<size_t>(index) >= output_methods_size)
+      return false;
+  }
+
+  for (size_t i = 0; i < outputs.size(); ++i) {
+    CALL_MEMBER_FN(*ukm_builder,
+                   kSegmentationUkmOutputMethods[output_indexes[i]])
+    (FloatToInt64(outputs[i]));
+  }
+  return true;
+}
+
// static
int64_t SegmentationUkmHelper::FloatToInt64(float f) {
// Encode the float number in IEEE754 double precision.