[Segmentation] Allow tensor uploading thru metadata
Change-Id: I6f8530cc92623d35fd639eb6060fae0ba148337f
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3846605
Reviewed-by: Siddhartha S <[email protected]>
Commit-Queue: Min Qin <[email protected]>
Cr-Commit-Position: refs/heads/main@{#1038883}
diff --git a/components/segmentation_platform/internal/data_collection/training_data_collector_impl.cc b/components/segmentation_platform/internal/data_collection/training_data_collector_impl.cc
index 6178fc4..5a4488a 100644
--- a/components/segmentation_platform/internal/data_collection/training_data_collector_impl.cc
+++ b/components/segmentation_platform/internal/data_collection/training_data_collector_impl.cc
@@ -106,15 +106,14 @@
histogram_signal_handler_->AddObserver(this);
DCHECK(segments);
- const base::flat_set<SegmentId>& allowed_ids =
- SegmentationUkmHelper::GetInstance()->allowed_segment_ids();
for (const auto& segment : *segments) {
+ const proto::SegmentInfo& segment_info = segment.second;
+
-    // Skip the segment if it is not in allowed list.
+    // Skip the segment if it is not allowed to upload tensors.
- if (!allowed_ids.contains(static_cast<int>(segment.first))) {
+ if (!SegmentationUkmHelper::GetInstance()->CanUploadTensors(segment_info)) {
continue;
}
- const proto::SegmentInfo& segment_info = segment.second;
// Validate segment info.
auto validation_result = metadata_utils::ValidateSegmentInfo(segment_info);
if (validation_result !=
diff --git a/components/segmentation_platform/internal/execution/model_executor_impl.cc b/components/segmentation_platform/internal/execution/model_executor_impl.cc
index b752e72..bfb8d57f 100644
--- a/components/segmentation_platform/internal/execution/model_executor_impl.cc
+++ b/components/segmentation_platform/internal/execution/model_executor_impl.cc
@@ -64,6 +64,7 @@
base::Time total_execution_start_time;
base::Time model_execution_start_time;
base::TimeDelta signal_storage_length;
+  bool upload_tensors = false;  // Set from CanUploadTensors().
};
ModelExecutorImpl::ModelExecutionTraceEvent::ModelExecutionTraceEvent(
@@ -125,6 +126,8 @@
segment_info.model_metadata();
state->signal_storage_length = model_metadata.signal_storage_length() *
metadata_utils::GetTimeUnit(model_metadata);
+ state->upload_tensors =
+ SegmentationUkmHelper::GetInstance()->CanUploadTensors(segment_info);
feature_list_query_processor_->ProcessFeatureList(
segment_info.model_metadata(), request->input_context, segment_id,
clock_->Now(), FeatureListQueryProcessor::ProcessOption::kInputsOnly,
@@ -184,9 +187,11 @@
stats::RecordModelExecutionResult(state->segment_id, result.value());
if (state->model_version && SegmentationUkmHelper::AllowedToUploadData(
state->signal_storage_length, clock_)) {
- SegmentationUkmHelper::GetInstance()->RecordModelExecutionResult(
- state->segment_id, state->model_version, state->input_tensor,
- result.value());
+ if (state->upload_tensors) {
+ SegmentationUkmHelper::GetInstance()->RecordModelExecutionResult(
+ state->segment_id, state->model_version, state->input_tensor,
+ result.value());
+ }
}
RunModelExecutionCallback(std::move(state), *result,
ModelExecutionStatus::kSuccess);
diff --git a/components/segmentation_platform/internal/segmentation_ukm_helper.cc b/components/segmentation_platform/internal/segmentation_ukm_helper.cc
index ab51818..324e752f 100644
--- a/components/segmentation_platform/internal/segmentation_ukm_helper.cc
+++ b/components/segmentation_platform/internal/segmentation_ukm_helper.cc
@@ -115,8 +115,9 @@
// Add inputs to ukm message.
if (!AddInputsToUkm(&execution_result, segment_id, model_version,
- input_tensor))
+ input_tensor)) {
return ukm::kInvalidSourceId;
+ }
// TODO(xingliu): Also record continuous outputs for model execution.
execution_result.SetPredictionResult(FloatToInt64(result))
@@ -162,9 +163,6 @@
SegmentId segment_id,
int64_t model_version,
const std::vector<float>& input_tensor) {
- if (!allowed_segment_ids_.contains(static_cast<int>(segment_id)))
- return false;
-
if (input_tensor.size() > ARRAY_SIZE(kSegmentationUkmInputMethods)) {
// Don't record UKM if there are too many tensors.
stats::RecordTooManyInputTensors(input_tensor.size());
@@ -202,6 +200,17 @@
return true;
}
+bool SegmentationUkmHelper::CanUploadTensors(
+    const proto::SegmentInfo& segment_info) const {
+  // Nothing may be uploaded while structured metrics are switched off.
+  const bool metrics_enabled = base::FeatureList::IsEnabled(
+      features::kSegmentationStructuredMetricsFeature);
+  // Otherwise the model metadata or the allowed-ID list can opt a segment in.
+  const int id = static_cast<int>(segment_info.segment_id());
+  return metrics_enabled && (segment_info.model_metadata().upload_tensors() ||
+                             allowed_segment_ids_.contains(id));
+}
+
// static
int64_t SegmentationUkmHelper::FloatToInt64(float f) {
// Encode the float number in IEEE754 double precision.
diff --git a/components/segmentation_platform/internal/segmentation_ukm_helper.h b/components/segmentation_platform/internal/segmentation_ukm_helper.h
index 80b2021..73f0ab4 100644
--- a/components/segmentation_platform/internal/segmentation_ukm_helper.h
+++ b/components/segmentation_platform/internal/segmentation_ukm_helper.h
@@ -59,6 +59,9 @@
absl::optional<proto::PredictionResult> prediction_result,
absl::optional<SelectedSegment> selected_segment);
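+  // Returns whether tensors may be uploaded for the segment described by
+  // |segment_info|: requires kSegmentationStructuredMetricsFeature, plus
+  // either the metadata's upload_tensors flag or an allow-listed segment ID.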
+ bool CanUploadTensors(const proto::SegmentInfo& segment_info) const;
+
// Helper method to encode a float number into int64.
static int64_t FloatToInt64(float f);
diff --git a/components/segmentation_platform/internal/segmentation_ukm_helper_unittest.cc b/components/segmentation_platform/internal/segmentation_ukm_helper_unittest.cc
index faf59f9..51170dd 100644
--- a/components/segmentation_platform/internal/segmentation_ukm_helper_unittest.cc
+++ b/components/segmentation_platform/internal/segmentation_ukm_helper_unittest.cc
@@ -17,6 +17,7 @@
#include "components/segmentation_platform/public/config.h"
#include "components/segmentation_platform/public/features.h"
#include "components/segmentation_platform/public/local_state_helper.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
#include "components/segmentation_platform/public/segmentation_platform_service.h"
#include "components/ukm/test_ukm_recorder.h"
#include "services/metrics/public/cpp/ukm_builders.h"
@@ -50,6 +51,14 @@
return result;
}
+proto::SegmentInfo CreateTestSegmentInfo(proto::SegmentId segment_id,
+                                         bool upload_tensors) {
+  proto::SegmentInfo info;  // Only the two fields set below are populated.
+  info.mutable_model_metadata()->set_upload_tensors(upload_tensors);
+  info.set_segment_id(segment_id);
+  return info;
+}
+
} // namespace
class SegmentationUkmHelperTest : public testing::Test {
@@ -163,24 +172,42 @@
});
}
-// Tests that recording is disabled if kSegmentationStructuredMetricsFeature
-// is disabled.
+// Tests that tensor uploading is disabled if
+// kSegmentationStructuredMetricsFeature is disabled.
TEST_F(SegmentationUkmHelperTest, TestDisabledStructuredMetrics) {
DisableStructureMetrics();
- std::vector<float> input_tensors = {0.1, 0.7, 0.8, 0.5};
- SegmentationUkmHelper::GetInstance()->RecordModelExecutionResult(
- proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 101, input_tensors, 0.6);
- ExpectEmptyUkmMetrics(Segmentation_ModelExecution::kEntryName);
+  // Metadata opts in, yet the disabled feature must still block uploading.
+  const proto::SegmentInfo info = CreateTestSegmentInfo(
+      proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, true);
+  EXPECT_FALSE(SegmentationUkmHelper::GetInstance()->CanUploadTensors(info));
}
-// Tests that recording is disabled for segment IDs that are not in the allowed
-// list.
+// Tests that tensor uploading is disabled when the segment ID is not in the
+// allowed list and the model metadata does not request uploading.
TEST_F(SegmentationUkmHelperTest, TestNotAllowedSegmentId) {
InitializeAllowedSegmentIds("7, 8");
- std::vector<float> input_tensors = {0.1, 0.7, 0.8, 0.5};
- SegmentationUkmHelper::GetInstance()->RecordModelExecutionResult(
- proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 101, input_tensors, 0.6);
- ExpectEmptyUkmMetrics(Segmentation_ModelExecution::kEntryName);
+  // Neither the allowed-ID list nor the metadata opts this segment in.
+  const proto::SegmentInfo info = CreateTestSegmentInfo(
+      proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, false);
+  EXPECT_FALSE(SegmentationUkmHelper::GetInstance()->CanUploadTensors(info));
+}
+
+// Tests that tensor uploading is enabled through finch param.
+TEST_F(SegmentationUkmHelperTest, TestUploadTensorsAllowedFromParam) {
+  InitializeAllowedSegmentIds("4, 7, 8");
+  // Metadata stays off, so only the allowed-ID list can enable uploading.
+  const proto::SegmentInfo info = CreateTestSegmentInfo(
+      proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, false);
+  EXPECT_TRUE(SegmentationUkmHelper::GetInstance()->CanUploadTensors(info));
+}
+
+// Tests that tensor uploading is enabled through metadata.
+TEST_F(SegmentationUkmHelperTest, TestUploadTensorsAllowedFromMetadata) {
+  InitializeAllowedSegmentIds("7, 8");
+  // Metadata opts in, so the allowed-ID list need not contain the segment.
+  const proto::SegmentInfo info = CreateTestSegmentInfo(
+      proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, true);
+  EXPECT_TRUE(SegmentationUkmHelper::GetInstance()->CanUploadTensors(info));
}
// Tests that float encoding works properly.