[Segmentation] Allow tensor uploading through metadata

Change-Id: I6f8530cc92623d35fd639eb6060fae0ba148337f
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3846605
Reviewed-by: Siddhartha S <[email protected]>
Commit-Queue: Min Qin <[email protected]>
Cr-Commit-Position: refs/heads/main@{#1038883}
diff --git a/components/segmentation_platform/internal/data_collection/training_data_collector_impl.cc b/components/segmentation_platform/internal/data_collection/training_data_collector_impl.cc
index 6178fc4..5a4488a 100644
--- a/components/segmentation_platform/internal/data_collection/training_data_collector_impl.cc
+++ b/components/segmentation_platform/internal/data_collection/training_data_collector_impl.cc
@@ -106,15 +106,14 @@
   histogram_signal_handler_->AddObserver(this);
 
   DCHECK(segments);
-  const base::flat_set<SegmentId>& allowed_ids =
-      SegmentationUkmHelper::GetInstance()->allowed_segment_ids();
   for (const auto& segment : *segments) {
+    const proto::SegmentInfo& segment_info = segment.second;
+
-    // Skip the segment if it is not in allowed list.
-    if (!allowed_ids.contains(static_cast<int>(segment.first))) {
+    // Skip the segment if it is not allowed to upload training tensors.
+    if (!SegmentationUkmHelper::GetInstance()->CanUploadTensors(segment_info)) {
       continue;
     }
 
-    const proto::SegmentInfo& segment_info = segment.second;
     // Validate segment info.
     auto validation_result = metadata_utils::ValidateSegmentInfo(segment_info);
     if (validation_result !=
diff --git a/components/segmentation_platform/internal/execution/model_executor_impl.cc b/components/segmentation_platform/internal/execution/model_executor_impl.cc
index b752e72..bfb8d57f 100644
--- a/components/segmentation_platform/internal/execution/model_executor_impl.cc
+++ b/components/segmentation_platform/internal/execution/model_executor_impl.cc
@@ -64,6 +64,7 @@
   base::Time total_execution_start_time;
   base::Time model_execution_start_time;
   base::TimeDelta signal_storage_length;
+  bool upload_tensors = false;  // Set from CanUploadTensors().
 };
 
 ModelExecutorImpl::ModelExecutionTraceEvent::ModelExecutionTraceEvent(
@@ -125,6 +126,8 @@
       segment_info.model_metadata();
   state->signal_storage_length = model_metadata.signal_storage_length() *
                                  metadata_utils::GetTimeUnit(model_metadata);
+  state->upload_tensors =
+      SegmentationUkmHelper::GetInstance()->CanUploadTensors(segment_info);
   feature_list_query_processor_->ProcessFeatureList(
       segment_info.model_metadata(), request->input_context, segment_id,
       clock_->Now(), FeatureListQueryProcessor::ProcessOption::kInputsOnly,
@@ -184,9 +187,11 @@
     stats::RecordModelExecutionResult(state->segment_id, result.value());
     if (state->model_version && SegmentationUkmHelper::AllowedToUploadData(
                                     state->signal_storage_length, clock_)) {
-      SegmentationUkmHelper::GetInstance()->RecordModelExecutionResult(
-          state->segment_id, state->model_version, state->input_tensor,
-          result.value());
+      if (state->upload_tensors) {
+        SegmentationUkmHelper::GetInstance()->RecordModelExecutionResult(
+            state->segment_id, state->model_version, state->input_tensor,
+            result.value());
+      }
     }
     RunModelExecutionCallback(std::move(state), *result,
                               ModelExecutionStatus::kSuccess);
diff --git a/components/segmentation_platform/internal/segmentation_ukm_helper.cc b/components/segmentation_platform/internal/segmentation_ukm_helper.cc
index ab51818..324e752f 100644
--- a/components/segmentation_platform/internal/segmentation_ukm_helper.cc
+++ b/components/segmentation_platform/internal/segmentation_ukm_helper.cc
@@ -115,8 +115,9 @@
 
   // Add inputs to ukm message.
   if (!AddInputsToUkm(&execution_result, segment_id, model_version,
-                      input_tensor))
+                      input_tensor)) {
     return ukm::kInvalidSourceId;
+  }
 
   // TODO(xingliu): Also record continuous outputs for model execution.
   execution_result.SetPredictionResult(FloatToInt64(result))
@@ -162,9 +163,8 @@
     SegmentId segment_id,
     int64_t model_version,
     const std::vector<float>& input_tensor) {
-  if (!allowed_segment_ids_.contains(static_cast<int>(segment_id)))
-    return false;
-
+  // No per-segment allowlist check happens here; callers are expected to gate
+  // uploads with CanUploadTensors() before recording.
   if (input_tensor.size() > ARRAY_SIZE(kSegmentationUkmInputMethods)) {
     // Don't record UKM if there are too many tensors.
     stats::RecordTooManyInputTensors(input_tensor.size());
@@ -202,6 +202,17 @@
   return true;
 }
 
+bool SegmentationUkmHelper::CanUploadTensors(
+    const proto::SegmentInfo& segment_info) const {
+  if (!base::FeatureList::IsEnabled(
+          features::kSegmentationStructuredMetricsFeature)) {
+    return false;
+  }
+  return segment_info.model_metadata().upload_tensors() ||
+         allowed_segment_ids_.contains(
+             static_cast<int>(segment_info.segment_id()));
+}
+
 // static
 int64_t SegmentationUkmHelper::FloatToInt64(float f) {
   // Encode the float number in IEEE754 double precision.
diff --git a/components/segmentation_platform/internal/segmentation_ukm_helper.h b/components/segmentation_platform/internal/segmentation_ukm_helper.h
index 80b2021..73f0ab4 100644
--- a/components/segmentation_platform/internal/segmentation_ukm_helper.h
+++ b/components/segmentation_platform/internal/segmentation_ukm_helper.h
@@ -59,6 +59,12 @@
       absl::optional<proto::PredictionResult> prediction_result,
       absl::optional<SelectedSegment> selected_segment);
 
+  // Returns whether training tensors for `segment_info` may be uploaded. This
+  // requires kSegmentationStructuredMetricsFeature to be enabled, and either
+  // the model metadata to set `upload_tensors` or the segment ID to be in the
+  // allowed segment ID list.
+  bool CanUploadTensors(const proto::SegmentInfo& segment_info) const;
+
   // Helper method to encode a float number into int64.
   static int64_t FloatToInt64(float f);
 
diff --git a/components/segmentation_platform/internal/segmentation_ukm_helper_unittest.cc b/components/segmentation_platform/internal/segmentation_ukm_helper_unittest.cc
index faf59f9..51170dd 100644
--- a/components/segmentation_platform/internal/segmentation_ukm_helper_unittest.cc
+++ b/components/segmentation_platform/internal/segmentation_ukm_helper_unittest.cc
@@ -17,6 +17,7 @@
 #include "components/segmentation_platform/public/config.h"
 #include "components/segmentation_platform/public/features.h"
 #include "components/segmentation_platform/public/local_state_helper.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 #include "components/segmentation_platform/public/segmentation_platform_service.h"
 #include "components/ukm/test_ukm_recorder.h"
 #include "services/metrics/public/cpp/ukm_builders.h"
@@ -50,6 +51,14 @@
   return result;
 }
 
+proto::SegmentInfo CreateTestSegmentInfo(proto::SegmentId segment_id,
+                                         bool upload_tensors) {
+  proto::SegmentInfo segment_info;
+  segment_info.set_segment_id(segment_id);
+  segment_info.mutable_model_metadata()->set_upload_tensors(upload_tensors);
+  return segment_info;
+}
+
 }  // namespace
 
 class SegmentationUkmHelperTest : public testing::Test {
@@ -163,24 +172,42 @@
                    });
 }
 
-// Tests that recording is disabled if kSegmentationStructuredMetricsFeature
-// is disabled.
+// Tests that tensor uploading is disabled when
+// kSegmentationStructuredMetricsFeature is off, even if metadata opts in.
 TEST_F(SegmentationUkmHelperTest, TestDisabledStructuredMetrics) {
   DisableStructureMetrics();
-  std::vector<float> input_tensors = {0.1, 0.7, 0.8, 0.5};
-  SegmentationUkmHelper::GetInstance()->RecordModelExecutionResult(
-      proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 101, input_tensors, 0.6);
-  ExpectEmptyUkmMetrics(Segmentation_ModelExecution::kEntryName);
+  proto::SegmentInfo segment_info = CreateTestSegmentInfo(
+      proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, true);
+  EXPECT_FALSE(
+      SegmentationUkmHelper::GetInstance()->CanUploadTensors(segment_info));
 }
 
-// Tests that recording is disabled for segment IDs that are not in the allowed
-// list.
+// Tests that tensor uploading is disabled for a segment that is neither in the
+// allowed list nor opted in through its metadata.
 TEST_F(SegmentationUkmHelperTest, TestNotAllowedSegmentId) {
   InitializeAllowedSegmentIds("7, 8");
-  std::vector<float> input_tensors = {0.1, 0.7, 0.8, 0.5};
-  SegmentationUkmHelper::GetInstance()->RecordModelExecutionResult(
-      proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 101, input_tensors, 0.6);
-  ExpectEmptyUkmMetrics(Segmentation_ModelExecution::kEntryName);
+  proto::SegmentInfo segment_info = CreateTestSegmentInfo(
+      proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, false);
+  EXPECT_FALSE(
+      SegmentationUkmHelper::GetInstance()->CanUploadTensors(segment_info));
+}
+
+// Tests that the Finch allowed-ID list alone enables tensor uploading.
+TEST_F(SegmentationUkmHelperTest, TestUploadTensorsAllowedFromParam) {
+  InitializeAllowedSegmentIds("4, 7, 8");
+  proto::SegmentInfo segment_info = CreateTestSegmentInfo(
+      proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, false);
+  EXPECT_TRUE(
+      SegmentationUkmHelper::GetInstance()->CanUploadTensors(segment_info));
+}
+
+// Tests that the metadata opt-in alone enables tensor uploading.
+TEST_F(SegmentationUkmHelperTest, TestUploadTensorsAllowedFromMetadata) {
+  InitializeAllowedSegmentIds("7, 8");
+  proto::SegmentInfo segment_info = CreateTestSegmentInfo(
+      proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, true);
+  EXPECT_TRUE(
+      SegmentationUkmHelper::GetInstance()->CanUploadTensors(segment_info));
 }
 
 // Tests that float encoding works properly.