[Segmentation] Check that validation messages only contain data after UKM approval

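The input tensors may contain data collected before UKM approval. This
CL prevents that data from being uploaded: model execution results are
recorded to UKM only when the most recent UKM-allowed time plus the
model's signal storage length is earlier than the current time. The
check is factored into SegmentationUkmHelper::AllowedToUploadData() and
shared with the training data collector.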

Bug: 1327419
Change-Id: I30749e20f229fcc011c285d4900bd4f7a9a84771
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3655131
Commit-Queue: Min Qin <[email protected]>
Reviewed-by: Siddhartha S <[email protected]>
Cr-Commit-Position: refs/heads/main@{#1005600}
diff --git a/components/segmentation_platform/internal/data_collection/training_data_collector_impl.cc b/components/segmentation_platform/internal/data_collection/training_data_collector_impl.cc
index c69b277..8298b050 100644
--- a/components/segmentation_platform/internal/data_collection/training_data_collector_impl.cc
+++ b/components/segmentation_platform/internal/data_collection/training_data_collector_impl.cc
@@ -235,10 +235,8 @@
   base::TimeDelta signal_storage_length =
       model_metadata.signal_storage_length() *
       metadata_utils::GetTimeUnit(model_metadata);
-  if (LocalStateHelper::GetInstance().GetPrefTime(
-          kSegmentationUkmMostRecentAllowedTimeKey) +
-          signal_storage_length >=
-      clock_->Now()) {
+  if (!SegmentationUkmHelper::AllowedToUploadData(signal_storage_length,
+                                                  clock_)) {
     RecordTrainingDataCollectionEvent(
         segment_info.segment_id(),
         stats::TrainingDataCollectionEvent::kPartialDataNotAllowed);
diff --git a/components/segmentation_platform/internal/execution/model_executor_impl.cc b/components/segmentation_platform/internal/execution/model_executor_impl.cc
index 1e7f30c..8db1f8f 100644
--- a/components/segmentation_platform/internal/execution/model_executor_impl.cc
+++ b/components/segmentation_platform/internal/execution/model_executor_impl.cc
@@ -62,6 +62,7 @@
   std::vector<float> input_tensor;
   base::Time total_execution_start_time;
   base::Time model_execution_start_time;
+  base::TimeDelta signal_storage_length;
 };
 
 ModelExecutorImpl::ModelExecutionTraceEvent::ModelExecutionTraceEvent(
@@ -120,6 +121,10 @@
   }
 
   state->model_version = segment_info.model_version();
+  const proto::SegmentationModelMetadata& model_metadata =
+      segment_info.model_metadata();
+  state->signal_storage_length = model_metadata.signal_storage_length() *
+                                 metadata_utils::GetTimeUnit(model_metadata);
   feature_list_query_processor_->ProcessFeatureList(
       segment_info.model_metadata(), segment_id, clock_->Now(),
       FeatureListQueryProcessor::ProcessOption::kInputsOnly,
@@ -180,7 +185,8 @@
             << optimization_guide::proto::OptimizationTarget_Name(
                    state->segment_id);
     stats::RecordModelExecutionResult(state->segment_id, result.value());
-    if (state->model_version) {
+    if (state->model_version && SegmentationUkmHelper::AllowedToUploadData(
+                                    state->signal_storage_length, clock_)) {
       SegmentationUkmHelper::GetInstance()->RecordModelExecutionResult(
           state->segment_id, state->model_version, state->input_tensor,
           result.value());
diff --git a/components/segmentation_platform/internal/segmentation_ukm_helper.cc b/components/segmentation_platform/internal/segmentation_ukm_helper.cc
index e976700a..24bc080 100644
--- a/components/segmentation_platform/internal/segmentation_ukm_helper.cc
+++ b/components/segmentation_platform/internal/segmentation_ukm_helper.cc
@@ -8,10 +8,13 @@
 #include "base/metrics/field_trial_params.h"
 #include "base/strings/string_number_conversions.h"
 #include "base/strings/string_split.h"
+#include "base/time/clock.h"
+#include "components/segmentation_platform/internal/constants.h"
 #include "components/segmentation_platform/internal/selection/segmentation_result_prefs.h"
 #include "components/segmentation_platform/internal/stats.h"
 #include "components/segmentation_platform/public/config.h"
 #include "components/segmentation_platform/public/features.h"
+#include "components/segmentation_platform/public/local_state_helper.h"
 #include "services/metrics/public/cpp/ukm_builders.h"
 #include "services/metrics/public/cpp/ukm_recorder.h"
 
@@ -204,4 +207,14 @@
   return base::bit_cast<int64_t>(static_cast<double>(f));
 }
 
+// static
+bool SegmentationUkmHelper::AllowedToUploadData(
+    base::TimeDelta signal_storage_length,
+    base::Clock* clock) {
+  return LocalStateHelper::GetInstance().GetPrefTime(  // Read from prefs.
+             kSegmentationUkmMostRecentAllowedTimeKey) +
+             signal_storage_length <  // Strict: false when equal to Now().
+         clock->Now();  // |clock| is dereferenced without a null check.
+}
+
 }  // namespace segmentation_platform
diff --git a/components/segmentation_platform/internal/segmentation_ukm_helper.h b/components/segmentation_platform/internal/segmentation_ukm_helper.h
index 05f33c0..4ceb6e5 100644
--- a/components/segmentation_platform/internal/segmentation_ukm_helper.h
+++ b/components/segmentation_platform/internal/segmentation_ukm_helper.h
@@ -7,6 +7,7 @@
 
 #include "base/containers/flat_set.h"
 #include "base/no_destructor.h"
+#include "base/time/time.h"
 #include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
 #include "services/metrics/public/cpp/ukm_source_id.h"
@@ -14,6 +15,10 @@
 
 using optimization_guide::proto::OptimizationTarget;
 
+namespace base {
+class Clock;
+}
+
 namespace ukm::builders {
 class Segmentation_ModelExecution;
 }  // namespace ukm::builders
@@ -57,6 +62,11 @@
   // Helper method to encode a float number into int64.
   static int64_t FloatToInt64(float f);
 
+  // True iff the pref time kSegmentationUkmMostRecentAllowedTimeKey plus
+  // |signal_storage_length| is strictly earlier than |clock|->Now().
+  static bool AllowedToUploadData(base::TimeDelta signal_storage_length,
+                                  base::Clock* clock);
+
   // Gets a set of segment IDs that are allowed to upload metrics.
   const base::flat_set<OptimizationTarget>& allowed_segment_ids() {
     return allowed_segment_ids_;
diff --git a/components/segmentation_platform/internal/segmentation_ukm_helper_unittest.cc b/components/segmentation_platform/internal/segmentation_ukm_helper_unittest.cc
index b73ff78..8a3d59ff 100644
--- a/components/segmentation_platform/internal/segmentation_ukm_helper_unittest.cc
+++ b/components/segmentation_platform/internal/segmentation_ukm_helper_unittest.cc
@@ -9,10 +9,15 @@
 #include "base/bit_cast.h"
 #include "base/test/metrics/histogram_tester.h"
 #include "base/test/scoped_feature_list.h"
+#include "base/test/simple_test_clock.h"
 #include "base/test/task_environment.h"
+#include "components/prefs/testing_pref_service.h"
+#include "components/segmentation_platform/internal/constants.h"
 #include "components/segmentation_platform/internal/selection/segmentation_result_prefs.h"
 #include "components/segmentation_platform/public/config.h"
 #include "components/segmentation_platform/public/features.h"
+#include "components/segmentation_platform/public/local_state_helper.h"
+#include "components/segmentation_platform/public/segmentation_platform_service.h"
 #include "components/ukm/test_ukm_recorder.h"
 #include "services/metrics/public/cpp/ukm_builders.h"
 #include "testing/gtest/include/gtest/gtest.h"
@@ -256,4 +261,23 @@
   ASSERT_NE(source_id, ukm::kInvalidSourceId);
 }
 
+TEST_F(SegmentationUkmHelperTest, AllowedToUploadData) {
+  TestingPrefServiceSimple prefs;
+  SegmentationPlatformService::RegisterLocalStatePrefs(prefs.registry());
+  LocalStateHelper::GetInstance().Initialize(&prefs);

+  base::SimpleTestClock clock;
+  clock.SetNow(base::Time::Now());
+  LocalStateHelper::GetInstance().SetPrefTime(  // Allowed time == clock.Now().
+      kSegmentationUkmMostRecentAllowedTimeKey, clock.Now());

+  ASSERT_FALSE(  // Allowed time + 1s is still ahead of the clock.
+      SegmentationUkmHelper::AllowedToUploadData(base::Seconds(1), &clock));
+  clock.Advance(base::Seconds(10));  // Clock is now allowed time + 10s.
+  ASSERT_TRUE(  // Allowed time + 1s is now earlier than the clock.
+      SegmentationUkmHelper::AllowedToUploadData(base::Seconds(1), &clock));
+  ASSERT_FALSE(  // Allowed time + 11s is later than the clock.
+      SegmentationUkmHelper::AllowedToUploadData(base::Seconds(11), &clock));
+}
+
 }  // namespace segmentation_platform