[Segmentation] Check that validation messages only contain data after UKM approval
The input tensors may contain data collected before UKM approval. This CL
adds a check so that such data is not recorded in validation messages.
Bug: 1327419
Change-Id: I30749e20f229fcc011c285d4900bd4f7a9a84771
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3655131
Commit-Queue: Min Qin <[email protected]>
Reviewed-by: Siddhartha S <[email protected]>
Cr-Commit-Position: refs/heads/main@{#1005600}
diff --git a/components/segmentation_platform/internal/data_collection/training_data_collector_impl.cc b/components/segmentation_platform/internal/data_collection/training_data_collector_impl.cc
index c69b277..8298b050 100644
--- a/components/segmentation_platform/internal/data_collection/training_data_collector_impl.cc
+++ b/components/segmentation_platform/internal/data_collection/training_data_collector_impl.cc
@@ -235,10 +235,8 @@
base::TimeDelta signal_storage_length =
model_metadata.signal_storage_length() *
metadata_utils::GetTimeUnit(model_metadata);
- if (LocalStateHelper::GetInstance().GetPrefTime(
- kSegmentationUkmMostRecentAllowedTimeKey) +
- signal_storage_length >=
- clock_->Now()) {
+ if (!SegmentationUkmHelper::AllowedToUploadData(signal_storage_length,
+ clock_)) {
RecordTrainingDataCollectionEvent(
segment_info.segment_id(),
stats::TrainingDataCollectionEvent::kPartialDataNotAllowed);
diff --git a/components/segmentation_platform/internal/execution/model_executor_impl.cc b/components/segmentation_platform/internal/execution/model_executor_impl.cc
index 1e7f30c..8db1f8f 100644
--- a/components/segmentation_platform/internal/execution/model_executor_impl.cc
+++ b/components/segmentation_platform/internal/execution/model_executor_impl.cc
@@ -62,6 +62,7 @@
std::vector<float> input_tensor;
base::Time total_execution_start_time;
base::Time model_execution_start_time;
+ base::TimeDelta signal_storage_length;
};
ModelExecutorImpl::ModelExecutionTraceEvent::ModelExecutionTraceEvent(
@@ -120,6 +121,10 @@
}
state->model_version = segment_info.model_version();
+ const proto::SegmentationModelMetadata& model_metadata =
+ segment_info.model_metadata();
+ state->signal_storage_length = model_metadata.signal_storage_length() *
+ metadata_utils::GetTimeUnit(model_metadata);
feature_list_query_processor_->ProcessFeatureList(
segment_info.model_metadata(), segment_id, clock_->Now(),
FeatureListQueryProcessor::ProcessOption::kInputsOnly,
@@ -180,7 +185,8 @@
<< optimization_guide::proto::OptimizationTarget_Name(
state->segment_id);
stats::RecordModelExecutionResult(state->segment_id, result.value());
- if (state->model_version) {
+ if (state->model_version && SegmentationUkmHelper::AllowedToUploadData(
+ state->signal_storage_length, clock_)) {
SegmentationUkmHelper::GetInstance()->RecordModelExecutionResult(
state->segment_id, state->model_version, state->input_tensor,
result.value());
diff --git a/components/segmentation_platform/internal/segmentation_ukm_helper.cc b/components/segmentation_platform/internal/segmentation_ukm_helper.cc
index e976700a..24bc080 100644
--- a/components/segmentation_platform/internal/segmentation_ukm_helper.cc
+++ b/components/segmentation_platform/internal/segmentation_ukm_helper.cc
@@ -8,10 +8,13 @@
#include "base/metrics/field_trial_params.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/string_split.h"
+#include "base/time/clock.h"
+#include "components/segmentation_platform/internal/constants.h"
#include "components/segmentation_platform/internal/selection/segmentation_result_prefs.h"
#include "components/segmentation_platform/internal/stats.h"
#include "components/segmentation_platform/public/config.h"
#include "components/segmentation_platform/public/features.h"
+#include "components/segmentation_platform/public/local_state_helper.h"
#include "services/metrics/public/cpp/ukm_builders.h"
#include "services/metrics/public/cpp/ukm_recorder.h"
@@ -204,4 +207,14 @@
return base::bit_cast<int64_t>(static_cast<double>(f));
}
+// static
+bool SegmentationUkmHelper::AllowedToUploadData(
+ base::TimeDelta signal_storage_length,
+ base::Clock* clock) {
+ return LocalStateHelper::GetInstance().GetPrefTime(
+ kSegmentationUkmMostRecentAllowedTimeKey) +
+ signal_storage_length <  // Strict: a sum equal to Now() is not allowed.
+ clock->Now();
+}
+
} // namespace segmentation_platform
diff --git a/components/segmentation_platform/internal/segmentation_ukm_helper.h b/components/segmentation_platform/internal/segmentation_ukm_helper.h
index 05f33c0..4ceb6e5 100644
--- a/components/segmentation_platform/internal/segmentation_ukm_helper.h
+++ b/components/segmentation_platform/internal/segmentation_ukm_helper.h
@@ -7,6 +7,7 @@
#include "base/containers/flat_set.h"
#include "base/no_destructor.h"
+#include "base/time/time.h"
#include "components/optimization_guide/proto/models.pb.h"
#include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
#include "services/metrics/public/cpp/ukm_source_id.h"
@@ -14,6 +15,10 @@
using optimization_guide::proto::OptimizationTarget;
+namespace base {
+class Clock;
+}
+
namespace ukm::builders {
class Segmentation_ModelExecution;
} // namespace ukm::builders
@@ -57,6 +62,11 @@
// Helper method to encode a float number into int64.
static int64_t FloatToInt64(float f);
+ // Returns true if the stored most-recent UKM-allowed time plus
+ // |signal_storage_length| is strictly earlier than |clock|->Now().
+ static bool AllowedToUploadData(base::TimeDelta signal_storage_length,
+ base::Clock* clock);
+
// Gets a set of segment IDs that are allowed to upload metrics.
const base::flat_set<OptimizationTarget>& allowed_segment_ids() {
return allowed_segment_ids_;
diff --git a/components/segmentation_platform/internal/segmentation_ukm_helper_unittest.cc b/components/segmentation_platform/internal/segmentation_ukm_helper_unittest.cc
index b73ff78..8a3d59ff 100644
--- a/components/segmentation_platform/internal/segmentation_ukm_helper_unittest.cc
+++ b/components/segmentation_platform/internal/segmentation_ukm_helper_unittest.cc
@@ -9,10 +9,15 @@
#include "base/bit_cast.h"
#include "base/test/metrics/histogram_tester.h"
#include "base/test/scoped_feature_list.h"
+#include "base/test/simple_test_clock.h"
#include "base/test/task_environment.h"
+#include "components/prefs/testing_pref_service.h"
+#include "components/segmentation_platform/internal/constants.h"
#include "components/segmentation_platform/internal/selection/segmentation_result_prefs.h"
#include "components/segmentation_platform/public/config.h"
#include "components/segmentation_platform/public/features.h"
+#include "components/segmentation_platform/public/local_state_helper.h"
+#include "components/segmentation_platform/public/segmentation_platform_service.h"
#include "components/ukm/test_ukm_recorder.h"
#include "services/metrics/public/cpp/ukm_builders.h"
#include "testing/gtest/include/gtest/gtest.h"
@@ -256,4 +261,23 @@
ASSERT_NE(source_id, ukm::kInvalidSourceId);
}
+TEST_F(SegmentationUkmHelperTest, AllowedToUploadData) {
+ TestingPrefServiceSimple prefs;
+ SegmentationPlatformService::RegisterLocalStatePrefs(prefs.registry());
+ LocalStateHelper::GetInstance().Initialize(&prefs);
+ // Store the test clock's current time as the most recent allowed time.
+ base::SimpleTestClock clock;
+ clock.SetNow(base::Time::Now());
+ LocalStateHelper::GetInstance().SetPrefTime(
+ kSegmentationUkmMostRecentAllowedTimeKey, clock.Now());
+ // No time has elapsed since the stored time, so a 1s length is rejected.
+ ASSERT_FALSE(
+ SegmentationUkmHelper::AllowedToUploadData(base::Seconds(1), &clock));
+ clock.Advance(base::Seconds(10));  // 10s elapsed: 1s passes, 11s does not.
+ ASSERT_TRUE(
+ SegmentationUkmHelper::AllowedToUploadData(base::Seconds(1), &clock));
+ ASSERT_FALSE(
+ SegmentationUkmHelper::AllowedToUploadData(base::Seconds(11), &clock));
+}
+
} // namespace segmentation_platform