Segmentation platform: Support recording training data to UKM.

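Adds SegmentationUkmHelper::RecordTrainingData() to record model inputs
and actual results to UKM as training data. Input recording is factored
out into a shared AddInputsToUkm() helper.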

Bug: 1295447
Change-Id: I5f62b34ee17ecc3ff1d11c69563a6e95cebbe612
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3482085
Reviewed-by: Min Qin <[email protected]>
Reviewed-by: Siddhartha S <[email protected]>
Commit-Queue: Xing Liu <[email protected]>
Cr-Commit-Position: refs/heads/main@{#974412}
diff --git a/components/segmentation_platform/internal/segmentation_ukm_helper.cc b/components/segmentation_platform/internal/segmentation_ukm_helper.cc
index c1c434493..8a99d1b 100644
--- a/components/segmentation_platform/internal/segmentation_ukm_helper.cc
+++ b/components/segmentation_platform/internal/segmentation_ukm_helper.cc
@@ -23,7 +23,7 @@
 using UkmMemberFn =
     Segmentation_ModelExecution& (Segmentation_ModelExecution::*)(int64_t);
 
-const UkmMemberFn kSegmentationUkmMethods[] = {
+const UkmMemberFn kSegmentationUkmInputMethods[] = {
     &Segmentation_ModelExecution::SetInput0,
     &Segmentation_ModelExecution::SetInput1,
     &Segmentation_ModelExecution::SetInput2,
@@ -55,6 +55,16 @@
     &Segmentation_ModelExecution::SetInput28,
     &Segmentation_ModelExecution::SetInput29};
 
+// Setters for the ActualResult..ActualResult6 UKM metrics. AddOutputsToUkm()
+// selects the setter for each output through its entry in |output_indexes|.
+const UkmMemberFn kSegmentationUkmOutputMethods[] = {
+    &Segmentation_ModelExecution::SetActualResult,
+    &Segmentation_ModelExecution::SetActualResult2,
+    &Segmentation_ModelExecution::SetActualResult3,
+    &Segmentation_ModelExecution::SetActualResult4,
+    &Segmentation_ModelExecution::SetActualResult5,
+    &Segmentation_ModelExecution::SetActualResult6};
+
 // Gets a set of segment IDs that are allowed to upload metrics.
 base::flat_set<int> GetSegmentIdsAllowedForReporting() {
   std::vector<std::string> segment_ids = base::SplitString(
@@ -98,30 +106,104 @@
     int64_t model_version,
     const std::vector<float>& input_tensor,
     float result) {
-  // Check if the |segment_id| is allowed to record metrics.
-  if (!allowed_segment_ids_.contains(static_cast<int>(segment_id)))
-    return ukm::kInvalidSourceId;
-
   ukm::SourceId source_id = ukm::NoURLSourceId();
-  // Don't record UKM if there are too many tensors.
-  if (input_tensor.size() > ARRAY_SIZE(kSegmentationUkmMethods)) {
-    stats::RecordTooManyInputTensors(input_tensor.size());
-    return ukm::kInvalidSourceId;
-  }
   ukm::builders::Segmentation_ModelExecution execution_result(source_id);
 
-  execution_result.SetOptimizationTarget(segment_id)
-      .SetModelVersion(model_version);
+  // Add inputs to ukm message.
+  if (!AddInputsToUkm(&execution_result, segment_id, model_version,
+                      input_tensor)) {
+    return ukm::kInvalidSourceId;
+  }
 
-  for (size_t i = 0; i < input_tensor.size(); ++i) {
-    CALL_MEMBER_FN(execution_result, kSegmentationUkmMethods[i])
-    (FloatToInt64(input_tensor[i]));
-  }
+  // TODO(xingliu): Also record continuous outputs for model execution.
   execution_result.SetPredictionResult(FloatToInt64(result))
       .Record(ukm::UkmRecorder::Get());
   return source_id;
 }
 
+// Records one Segmentation_ModelExecution UKM event with the model inputs and
+// the actual results, for use as training data. outputs[i] is written to the
+// ActualResult metric selected by output_indexes[i]. Returns the source ID of
+// the recorded event, or ukm::kInvalidSourceId if nothing was recorded.
+ukm::SourceId SegmentationUkmHelper::RecordTrainingData(
+    OptimizationTarget segment_id,
+    int64_t model_version,
+    const std::vector<float>& input_tensor,
+    const std::vector<float>& outputs,
+    const std::vector<int>& output_indexes) {
+  ukm::SourceId source_id = ukm::NoURLSourceId();
+  ukm::builders::Segmentation_ModelExecution execution_result(source_id);
+  if (!AddInputsToUkm(&execution_result, segment_id, model_version,
+                      input_tensor)) {
+    return ukm::kInvalidSourceId;
+  }
+
+  if (!AddOutputsToUkm(&execution_result, outputs, output_indexes)) {
+    return ukm::kInvalidSourceId;
+  }
+
+  execution_result.Record(ukm::UkmRecorder::Get());
+  return source_id;
+}
+
+// Sets the optimization target, model version and one InputN metric per
+// element of |input_tensor| on |ukm_builder|. Returns false without touching
+// the builder if |segment_id| is not allowed to report metrics, or if there
+// are more inputs than InputN setters (see stats::RecordTooManyInputTensors).
+bool SegmentationUkmHelper::AddInputsToUkm(
+    ukm::builders::Segmentation_ModelExecution* ukm_builder,
+    OptimizationTarget segment_id,
+    int64_t model_version,
+    const std::vector<float>& input_tensor) {
+  if (!allowed_segment_ids_.contains(static_cast<int>(segment_id)))
+    return false;
+
+  if (input_tensor.size() > ARRAY_SIZE(kSegmentationUkmInputMethods)) {
+    // Don't record UKM if there are too many tensors.
+    stats::RecordTooManyInputTensors(input_tensor.size());
+    return false;
+  }
+
+  ukm_builder->SetOptimizationTarget(segment_id).SetModelVersion(model_version);
+  for (size_t i = 0; i < input_tensor.size(); ++i) {
+    CALL_MEMBER_FN(*ukm_builder, kSegmentationUkmInputMethods[i])
+    (FloatToInt64(input_tensor[i]));
+  }
+  return true;
+}
+
+// Writes outputs[i] to the ActualResult metric selected by output_indexes[i].
+// Returns false without touching |ukm_builder| if the two vectors differ in
+// size, if there are more outputs than ActualResult setters, or if any index
+// is out of range.
+bool SegmentationUkmHelper::AddOutputsToUkm(
+    ukm::builders::Segmentation_ModelExecution* ukm_builder,
+    const std::vector<float>& outputs,
+    const std::vector<int>& output_indexes) {
+  DCHECK(!outputs.empty());
+  if (outputs.size() != output_indexes.size())
+    return false;
+
+  const size_t output_methods_size = ARRAY_SIZE(kSegmentationUkmOutputMethods);
+  if (outputs.size() > output_methods_size)
+    return false;
+
+  // Validate every index before writing anything. Negative indexes must be
+  // rejected too; otherwise they would read before the start of the array.
+  for (int index : output_indexes) {
+    if (index < 0 || static_cast<size_t>(index) >= output_methods_size)
+      return false;
+  }
+
+  for (size_t i = 0; i < outputs.size(); ++i) {
+    CALL_MEMBER_FN(*ukm_builder,
+                   kSegmentationUkmOutputMethods[output_indexes[i]])
+    (FloatToInt64(outputs[i]));
+  }
+
+  return true;
+}
+
 // static
 int64_t SegmentationUkmHelper::FloatToInt64(float f) {
   // Encode the float number in IEEE754 double precision.