[segmentation_platform] Store results in prefs
This CL adds
1- Implemented SegmentationResultPrefs for storing segmentation results.
2- Register prefs with chrome that will be loaded with profile.
3- Set up finch params to be used in place of constants.
Bug: 1218522
Change-Id: Ia5a071f71754092931a2a6ed4589e7afe8ec80de
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2929114
Reviewed-by: Colin Blundell <[email protected]>
Reviewed-by: Tommy Nyquist <[email protected]>
Commit-Queue: Shakti Sahu <[email protected]>
Cr-Commit-Position: refs/heads/master@{#893497}
diff --git a/chrome/browser/prefs/browser_prefs.cc b/chrome/browser/prefs/browser_prefs.cc
index 7c2f450..4ecdb06 100644
--- a/chrome/browser/prefs/browser_prefs.cc
+++ b/chrome/browser/prefs/browser_prefs.cc
@@ -139,6 +139,7 @@
#include "components/search_engines/template_url_prepopulate_data.h"
#include "components/security_interstitials/content/insecure_form_blocking_page.h"
#include "components/security_interstitials/content/stateful_ssl_host_state_delegate.h"
+#include "components/segmentation_platform/public/segmentation_platform_service.h"
#include "components/sessions/core/session_id_generator.h"
#include "components/signin/public/identity_manager/identity_manager.h"
#include "components/site_engagement/content/site_engagement_service.h"
@@ -1023,6 +1024,8 @@
registry);
security_interstitials::InsecureFormBlockingPage::RegisterProfilePrefs(
registry);
+ segmentation_platform::SegmentationPlatformService::RegisterProfilePrefs(
+ registry);
SessionStartupPref::RegisterProfilePrefs(registry);
SharingSyncPreference::RegisterProfilePrefs(registry);
site_engagement::SiteEngagementService::RegisterProfilePrefs(registry);
diff --git a/components/segmentation_platform/DEPS b/components/segmentation_platform/DEPS
index 46291ef..98829d9f 100644
--- a/components/segmentation_platform/DEPS
+++ b/components/segmentation_platform/DEPS
@@ -3,6 +3,7 @@
"+components/leveldb_proto",
"+components/optimization_guide",
"-components/optimization_guide/content",
+ "+components/prefs",
"+third_party/tflite",
"+third_party/tflite-support",
]
diff --git a/components/segmentation_platform/components_unittests.filter b/components/segmentation_platform/components_unittests.filter
index 9fc49e7..2d35f0e 100644
--- a/components/segmentation_platform/components_unittests.filter
+++ b/components/segmentation_platform/components_unittests.filter
@@ -6,7 +6,9 @@
ModelExecutionManagerTest.*
ModelExecutionSchedulerTest.*
SegmentationModelExecutorTest.*
+SegmentationPlatformFeaturesTest.*
SegmentationPlatformServiceImplTest.*
+SegmentationResultPrefsTest.*
SegmentInfoDatabaseTest.*
SegmentSelectorTest.*
SignalDatabaseImplTest.*
diff --git a/components/segmentation_platform/internal/BUILD.gn b/components/segmentation_platform/internal/BUILD.gn
index 7a2d9a5..9982713 100644
--- a/components/segmentation_platform/internal/BUILD.gn
+++ b/components/segmentation_platform/internal/BUILD.gn
@@ -15,6 +15,8 @@
]
sources = [
+ "constants.cc",
+ "constants.h",
"database/metadata_utils.cc",
"database/metadata_utils.h",
"database/segment_info_database.cc",
@@ -40,6 +42,8 @@
"scheduler/model_execution_scheduler.h",
"scheduler/model_execution_scheduler_impl.cc",
"scheduler/model_execution_scheduler_impl.h",
+ "segmentation_platform_features.cc",
+ "segmentation_platform_features.h",
"segmentation_platform_service_impl.cc",
"segmentation_platform_service_impl.h",
"selection/segment_selector.h",
@@ -57,8 +61,10 @@
deps = [
"//base",
+ "//base/util/values:values_util",
"//components/keyed_service/core",
"//components/leveldb_proto",
+ "//components/prefs",
"//components/segmentation_platform/internal/proto",
"//components/segmentation_platform/public",
]
@@ -107,8 +113,10 @@
"execution/feature_aggregator_impl_unittest.cc",
"execution/model_execution_manager_factory_unittest.cc",
"scheduler/model_execution_scheduler_unittest.cc",
+ "segmentation_platform_features_unittest.cc",
"segmentation_platform_service_impl_unittest.cc",
"selection/segment_selector_unittest.cc",
+ "selection/segmentation_result_prefs_unittest.cc",
"signals/histogram_signal_handler_unittest.cc",
"signals/signal_filter_processor_unittest.cc",
"signals/user_action_signal_handler_unittest.cc",
@@ -120,6 +128,8 @@
"//base/test:test_support",
"//components/leveldb_proto:test_support",
"//components/optimization_guide/core:test_support",
+ "//components/prefs",
+ "//components/prefs:test_support",
"//components/segmentation_platform/internal/proto",
"//components/segmentation_platform/public",
"//testing/gmock",
diff --git a/components/segmentation_platform/internal/constants.cc b/components/segmentation_platform/internal/constants.cc
new file mode 100644
index 0000000..d767b4b
--- /dev/null
+++ b/components/segmentation_platform/internal/constants.cc
@@ -0,0 +1,13 @@
+// Copyright 2021 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "components/segmentation_platform/internal/constants.h"
+
+namespace segmentation_platform {
+
+const char kAdaptiveToolbarSegmentationKey[] = "adaptive_toolbar";
+const char kSegmentationResultPref[] =
+ "segmentation_platform.segmentation_result";
+
+} // namespace segmentation_platform
diff --git a/components/segmentation_platform/internal/constants.h b/components/segmentation_platform/internal/constants.h
new file mode 100644
index 0000000..852cafff
--- /dev/null
+++ b/components/segmentation_platform/internal/constants.h
@@ -0,0 +1,18 @@
+// Copyright 2021 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_CONSTANTS_H_
+#define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_CONSTANTS_H_
+
+namespace segmentation_platform {
+
+// The key to be used to find discrete mapping for adaptive toolbar feature.
+extern const char kAdaptiveToolbarSegmentationKey[];
+
+// The path to the pref storing the segmentation result.
+extern const char kSegmentationResultPref[];
+
+} // namespace segmentation_platform
+
+#endif // COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_CONSTANTS_H_
diff --git a/components/segmentation_platform/internal/database/metadata_utils.cc b/components/segmentation_platform/internal/database/metadata_utils.cc
index 71c6ddb..b8f340db 100644
--- a/components/segmentation_platform/internal/database/metadata_utils.cc
+++ b/components/segmentation_platform/internal/database/metadata_utils.cc
@@ -5,16 +5,10 @@
#include "components/segmentation_platform/internal/database/metadata_utils.h"
#include "base/notreached.h"
+#include "components/segmentation_platform/internal/segmentation_platform_features.h"
namespace segmentation_platform {
namespace metadata_utils {
-namespace {
-// Used to determine if the model was executed too recently to run again.
-// TODO(shaktisahu): Make this finch configurable.
-constexpr base::TimeDelta kFreshResultsDurationThreshold =
- base::TimeDelta::FromHours(24);
-
-} // namespace
ValidationResult ValidateSegmentInfo(const proto::SegmentInfo& segment_info) {
if (!segment_info.has_segment_id())
@@ -59,7 +53,7 @@
segment_info.prediction_result().timestamp_us()));
return base::Time::Now() - last_result_timestamp <
- kFreshResultsDurationThreshold;
+ features::GetMinDelayForModelRerun();
}
base::TimeDelta GetTimeUnit(
diff --git a/components/segmentation_platform/internal/database/segment_info_database.h b/components/segmentation_platform/internal/database/segment_info_database.h
index a31dd67..44e15c0 100644
--- a/components/segmentation_platform/internal/database/segment_info_database.h
+++ b/components/segmentation_platform/internal/database/segment_info_database.h
@@ -23,9 +23,6 @@
class PredictionResult;
} // namespace proto
-// The key to be used to find discrete mapping for segmentation.
-constexpr char kSegmentationDiscreteMappingKey[] = "segmentation";
-
// Represents a DB layer that stores model metadata and prediction results to
// the disk.
class SegmentInfoDatabase {
diff --git a/components/segmentation_platform/internal/database/test_segment_info_database.cc b/components/segmentation_platform/internal/database/test_segment_info_database.cc
index c33d7a7..007b5a1 100644
--- a/components/segmentation_platform/internal/database/test_segment_info_database.cc
+++ b/components/segmentation_platform/internal/database/test_segment_info_database.cc
@@ -8,6 +8,7 @@
#include "base/metrics/metrics_hashes.h"
#include "components/optimization_guide/proto/models.pb.h"
+#include "components/segmentation_platform/internal/constants.h"
#include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
@@ -112,14 +113,15 @@
timestamp.ToDeltaSinceWindowsEpoch().InMicroseconds());
}
-void TestSegmentInfoDatabase::AddDiscreteMapping(OptimizationTarget segment_id,
- float mappings[][2],
- int num_pairs) {
+void TestSegmentInfoDatabase::AddDiscreteMapping(
+ OptimizationTarget segment_id,
+ float mappings[][2],
+ int num_pairs,
+ const std::string& discrete_mapping_key) {
proto::SegmentInfo* info = FindOrCreateSegment(segment_id);
auto* discrete_mappings_map =
info->mutable_model_metadata()->mutable_discrete_mappings();
- auto& discrete_mappings =
- (*discrete_mappings_map)[kSegmentationDiscreteMappingKey];
+ auto& discrete_mappings = (*discrete_mappings_map)[discrete_mapping_key];
for (int i = 0; i < num_pairs; i++) {
auto* pair = mappings[i];
auto* entry = discrete_mappings.add_entries();
diff --git a/components/segmentation_platform/internal/database/test_segment_info_database.h b/components/segmentation_platform/internal/database/test_segment_info_database.h
index 4fc8edd..c4698216 100644
--- a/components/segmentation_platform/internal/database/test_segment_info_database.h
+++ b/components/segmentation_platform/internal/database/test_segment_info_database.h
@@ -47,7 +47,8 @@
base::Time timestamp);
void AddDiscreteMapping(OptimizationTarget segment_id,
float mappings[][2],
- int num_pairs);
+ int num_pairs,
+ const std::string& discrete_mapping_key);
void SetBucketDuration(OptimizationTarget segment_id,
int64_t bucket_duration,
proto::TimeUnit time_unit);
diff --git a/components/segmentation_platform/internal/segmentation_platform_features.cc b/components/segmentation_platform/internal/segmentation_platform_features.cc
new file mode 100644
index 0000000..5a1a3b2
--- /dev/null
+++ b/components/segmentation_platform/internal/segmentation_platform_features.cc
@@ -0,0 +1,41 @@
+// Copyright 2021 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "components/segmentation_platform/internal/segmentation_platform_features.h"
+
+namespace segmentation_platform {
+namespace features {
+namespace {
+
+// Default min delay in seconds between two successful executions of a model.
+// Default is 12 hours.
+constexpr int kDefaultMinDelayForModelRerunSeconds = 43200;
+
+// Default TTL for segment selection.
+constexpr int kDefaultSegmentSelectionTTLDays = 28;
+
+} // namespace
+
+// Core feature flag for segmentation platform.
+const base::Feature kSegmentationPlatformFeature{
+ "SegmentationPlatform", base::FEATURE_ENABLED_BY_DEFAULT};
+
+base::TimeDelta GetMinDelayForModelRerun() {
+ int min_delay_seconds = base::GetFieldTrialParamByFeatureAsInt(
+ kSegmentationPlatformFeature, "min_delay_for_model_rerun_seconds",
+ kDefaultMinDelayForModelRerunSeconds);
+
+ return base::TimeDelta::FromSeconds(min_delay_seconds);
+}
+
+base::TimeDelta GetSegmentSelectionTTL() {
+ int segment_selection_ttl_days = base::GetFieldTrialParamByFeatureAsInt(
+ kSegmentationPlatformFeature, "segment_selection_ttl_days",
+ kDefaultSegmentSelectionTTLDays);
+
+ return base::TimeDelta::FromDays(segment_selection_ttl_days);
+}
+
+} // namespace features
+} // namespace segmentation_platform
diff --git a/components/segmentation_platform/internal/segmentation_platform_features.h b/components/segmentation_platform/internal/segmentation_platform_features.h
new file mode 100644
index 0000000..04546dea
--- /dev/null
+++ b/components/segmentation_platform/internal/segmentation_platform_features.h
@@ -0,0 +1,26 @@
+// Copyright 2021 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SEGMENTATION_PLATFORM_FEATURES_H_
+#define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SEGMENTATION_PLATFORM_FEATURES_H_
+
+#include "base/feature_list.h"
+#include "base/time/time.h"
+
+namespace segmentation_platform {
+namespace features {
+
+extern const base::Feature kSegmentationPlatformFeature;
+
+// Used to determine if the model was executed too recently to run again.
+base::TimeDelta GetMinDelayForModelRerun();
+
+// Time to live for a segment selection. Segment selection can't be changed
+// before this duration.
+base::TimeDelta GetSegmentSelectionTTL();
+
+} // namespace features
+} // namespace segmentation_platform
+
+#endif // COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SEGMENTATION_PLATFORM_FEATURES_H_
diff --git a/components/segmentation_platform/internal/segmentation_platform_features_unittest.cc b/components/segmentation_platform/internal/segmentation_platform_features_unittest.cc
new file mode 100644
index 0000000..4f544393
--- /dev/null
+++ b/components/segmentation_platform/internal/segmentation_platform_features_unittest.cc
@@ -0,0 +1,27 @@
+// Copyright 2021 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "components/segmentation_platform/internal/segmentation_platform_features.h"
+
+#include "base/test/scoped_feature_list.h"
+#include "testing/gmock/include/gmock/gmock.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace segmentation_platform {
+
+class SegmentationPlatformFeaturesTest : public testing::Test {
+ public:
+ SegmentationPlatformFeaturesTest() = default;
+ ~SegmentationPlatformFeaturesTest() override = default;
+};
+
+TEST_F(SegmentationPlatformFeaturesTest, DefaultValues) {
+ base::test::ScopedFeatureList feature_list;
+ feature_list.InitAndEnableFeature(features::kSegmentationPlatformFeature);
+ EXPECT_EQ(base::TimeDelta::FromSeconds(43200),
+ features::GetMinDelayForModelRerun());
+ EXPECT_EQ(base::TimeDelta::FromDays(28), features::GetSegmentSelectionTTL());
+}
+
+} // namespace segmentation_platform
diff --git a/components/segmentation_platform/internal/segmentation_platform_service_impl.cc b/components/segmentation_platform/internal/segmentation_platform_service_impl.cc
index e4c57a2..4137763 100644
--- a/components/segmentation_platform/internal/segmentation_platform_service_impl.cc
+++ b/components/segmentation_platform/internal/segmentation_platform_service_impl.cc
@@ -4,10 +4,19 @@
#include "components/segmentation_platform/internal/segmentation_platform_service_impl.h"
+#include "components/prefs/pref_registry_simple.h"
+#include "components/segmentation_platform/internal/constants.h"
+
namespace segmentation_platform {
SegmentationPlatformServiceImpl::SegmentationPlatformServiceImpl() = default;
SegmentationPlatformServiceImpl::~SegmentationPlatformServiceImpl() = default;
+// static
+void SegmentationPlatformService::RegisterProfilePrefs(
+ PrefRegistrySimple* registry) {
+ registry->RegisterDictionaryPref(kSegmentationResultPref);
+}
+
} // namespace segmentation_platform
diff --git a/components/segmentation_platform/internal/selection/segment_selector_impl.cc b/components/segmentation_platform/internal/selection/segment_selector_impl.cc
index dc6cdb3..b56dc8a 100644
--- a/components/segmentation_platform/internal/selection/segment_selector_impl.cc
+++ b/components/segmentation_platform/internal/selection/segment_selector_impl.cc
@@ -4,27 +4,30 @@
#include "components/segmentation_platform/internal/selection/segment_selector_impl.h"
-#include "base/logging.h"
#include "base/threading/thread_task_runner_handle.h"
+#include "components/segmentation_platform/internal/constants.h"
#include "components/segmentation_platform/internal/database/segment_info_database.h"
#include "components/segmentation_platform/internal/proto/model_metadata.pb.h"
#include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
+#include "components/segmentation_platform/internal/segmentation_platform_features.h"
#include "components/segmentation_platform/internal/selection/segmentation_result_prefs.h"
namespace segmentation_platform {
SegmentSelectorImpl::SegmentSelectorImpl(SegmentInfoDatabase* segment_database,
- SegmentationResultPrefs* result_prefs)
+ SegmentationResultPrefs* result_prefs,
+ const std::string& segmentation_key)
: segment_database_(segment_database),
result_prefs_(result_prefs),
+ segmentation_key_(segmentation_key),
initialized_(false) {}
SegmentSelectorImpl::~SegmentSelectorImpl() = default;
void SegmentSelectorImpl::Initialize(base::OnceClosure callback) {
// Read selected segment from prefs.
- absl::optional<SelectedSegment> selected_segment =
- result_prefs_->ReadSegmentationResultFromPref();
+ const auto& selected_segment =
+ result_prefs_->ReadSegmentationResultFromPref(segmentation_key_);
if (selected_segment.has_value())
selected_segment_last_session_ = selected_segment->segment_id;
@@ -90,7 +93,7 @@
continue;
DCHECK(info.prediction_result().has_result());
- int score = ConvertToDiscreteScore(id, kSegmentationDiscreteMappingKey,
+ int score = ConvertToDiscreteScore(id, segmentation_key_,
info.prediction_result().result(),
info.model_metadata());
if (score > max_score) {
@@ -106,28 +109,29 @@
void SegmentSelectorImpl::UpdateSelectedSegment(
OptimizationTarget new_selection) {
- absl::optional<SelectedSegment> previous_selection =
- result_prefs_->ReadSegmentationResultFromPref();
+ const auto& previous_selection =
+ result_prefs_->ReadSegmentationResultFromPref(segmentation_key_);
bool skip_updating_prefs = false;
if (previous_selection.has_value()) {
- skip_updating_prefs =
- new_selection == previous_selection->segment_id ||
- (previous_selection->selection_time + kSegmentSelectionTTL >
- base::Time::Now());
+ skip_updating_prefs = new_selection == previous_selection->segment_id ||
+ (previous_selection->selection_time +
+ features::GetSegmentSelectionTTL() >
+ base::Time::Now());
// TODO(shaktisahu): Use segment selection inertia.
}
if (skip_updating_prefs)
return;
- // Write result to prefs.
+ // Write result to prefs. Delete if no valid selection.
absl::optional<SelectedSegment> updated_selection;
if (new_selection != OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN) {
updated_selection = absl::make_optional<SelectedSegment>(new_selection);
updated_selection->selection_time = base::Time::Now();
}
- result_prefs_->SaveSegmentationResultToPref(updated_selection);
+ result_prefs_->SaveSegmentationResultToPref(segmentation_key_,
+ updated_selection);
}
void SegmentSelectorImpl::ReadScoresFromLastSession(
diff --git a/components/segmentation_platform/internal/selection/segment_selector_impl.h b/components/segmentation_platform/internal/selection/segment_selector_impl.h
index a02d7b6..1de310e9 100644
--- a/components/segmentation_platform/internal/selection/segment_selector_impl.h
+++ b/components/segmentation_platform/internal/selection/segment_selector_impl.h
@@ -20,12 +20,11 @@
class SegmentationModelMetadata;
} // namespace proto
-constexpr base::TimeDelta kSegmentSelectionTTL = base::TimeDelta::FromDays(28);
-
class SegmentSelectorImpl : public SegmentSelector {
public:
SegmentSelectorImpl(SegmentInfoDatabase* segment_database,
- SegmentationResultPrefs* result_prefs);
+ SegmentationResultPrefs* result_prefs,
+ const std::string& segmentation_key);
~SegmentSelectorImpl() override;
@@ -83,6 +82,10 @@
// Helper class to read/write results to the prefs.
SegmentationResultPrefs* result_prefs_;
+ // The key specific to this selection, and used for finding the discrete
+ // mapping and writing to prefs.
+ const std::string segmentation_key_;
+
// These values are read from prefs or db on init and used for serving the
// clients in the current session.
absl::optional<OptimizationTarget> selected_segment_last_session_;
diff --git a/components/segmentation_platform/internal/selection/segment_selector_unittest.cc b/components/segmentation_platform/internal/selection/segment_selector_unittest.cc
index 4ca027a..44117b0 100644
--- a/components/segmentation_platform/internal/selection/segment_selector_unittest.cc
+++ b/components/segmentation_platform/internal/selection/segment_selector_unittest.cc
@@ -6,6 +6,7 @@
#include "base/run_loop.h"
#include "base/test/task_environment.h"
+#include "components/segmentation_platform/internal/constants.h"
#include "components/segmentation_platform/internal/database/segment_info_database.h"
#include "components/segmentation_platform/internal/database/test_segment_info_database.h"
#include "components/segmentation_platform/internal/scheduler/model_execution_scheduler.h"
@@ -21,13 +22,13 @@
class MockSegmentationResultPrefs : public SegmentationResultPrefs {
public:
- MockSegmentationResultPrefs() = default;
+ MockSegmentationResultPrefs() : SegmentationResultPrefs(nullptr) {}
MOCK_METHOD(void,
SaveSegmentationResultToPref,
- (const absl::optional<SelectedSegment>&));
+ (const std::string&, const absl::optional<SelectedSegment>&));
MOCK_METHOD(absl::optional<SelectedSegment>,
ReadSegmentationResultFromPref,
- ());
+ (const std::string&));
};
class MockModelExecutionScheduler : public ModelExecutionScheduler {
@@ -52,7 +53,7 @@
segment_database_ = std::make_unique<test::TestSegmentInfoDatabase>();
prefs_ = std::make_unique<MockSegmentationResultPrefs>();
segment_selector_ = std::make_unique<SegmentSelectorImpl>(
- segment_database_.get(), prefs_.get());
+ segment_database_.get(), prefs_.get(), kAdaptiveToolbarSegmentationKey);
segment_selector_->set_model_execution_scheduler(
&model_execution_scheduler_);
}
@@ -91,44 +92,48 @@
TEST_F(SegmentSelectorTest, CheckDiscreteMapping) {
OptimizationTarget segment_id =
OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
+ std::string segmentation_key = kAdaptiveToolbarSegmentationKey;
float mapping[][2] = {{0.2, 1}, {0.5, 3}, {0.7, 4}};
- segment_database_->AddDiscreteMapping(segment_id, mapping, 3);
+ segment_database_->AddDiscreteMapping(segment_id, mapping, 3,
+ segmentation_key);
proto::SegmentInfo* segment_info =
segment_database_->FindOrCreateSegment(segment_id);
const proto::SegmentationModelMetadata& metadata =
segment_info->model_metadata();
- ASSERT_EQ(0,
- ConvertToDiscreteScore(segment_id, "segmentation", 0.1, metadata));
- ASSERT_EQ(1,
- ConvertToDiscreteScore(segment_id, "segmentation", 0.4, metadata));
- ASSERT_EQ(3,
- ConvertToDiscreteScore(segment_id, "segmentation", 0.5, metadata));
- ASSERT_EQ(3,
- ConvertToDiscreteScore(segment_id, "segmentation", 0.6, metadata));
- ASSERT_EQ(4,
- ConvertToDiscreteScore(segment_id, "segmentation", 0.9, metadata));
+ ASSERT_EQ(
+ 0, ConvertToDiscreteScore(segment_id, segmentation_key, 0.1, metadata));
+ ASSERT_EQ(
+ 1, ConvertToDiscreteScore(segment_id, segmentation_key, 0.4, metadata));
+ ASSERT_EQ(
+ 3, ConvertToDiscreteScore(segment_id, segmentation_key, 0.5, metadata));
+ ASSERT_EQ(
+ 3, ConvertToDiscreteScore(segment_id, segmentation_key, 0.6, metadata));
+ ASSERT_EQ(
+ 4, ConvertToDiscreteScore(segment_id, segmentation_key, 0.9, metadata));
}
TEST_F(SegmentSelectorTest, FindBestSegmentFlowWithTwoSegments) {
OptimizationTarget segment_id =
OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
float mapping[][2] = {{0.2, 1}, {0.5, 3}, {0.7, 4}};
- segment_database_->AddDiscreteMapping(segment_id, mapping, 3);
+ segment_database_->AddDiscreteMapping(segment_id, mapping, 3,
+ kAdaptiveToolbarSegmentationKey);
OptimizationTarget segment_id2 =
OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE;
float mapping2[][2] = {{0.3, 1}, {0.4, 4}};
- segment_database_->AddDiscreteMapping(segment_id2, mapping2, 2);
+ segment_database_->AddDiscreteMapping(segment_id2, mapping2, 2,
+ kAdaptiveToolbarSegmentationKey);
base::Time result_timestamp = base::Time::Now();
segment_database_->AddPredictionResult(segment_id, 0.6, result_timestamp);
segment_database_->AddPredictionResult(segment_id2, 0.5, result_timestamp);
absl::optional<SelectedSegment> selected_segment;
- EXPECT_CALL(*prefs_, SaveSegmentationResultToPref(_))
+ EXPECT_CALL(*prefs_, SaveSegmentationResultToPref(_, _))
.Times(1)
- .WillOnce(SaveArg<0>(&selected_segment));
+ .WillOnce(SaveArg<1>(&selected_segment));
segment_selector_->OnModelExecutionCompleted(segment_id);
ASSERT_TRUE(selected_segment.has_value());
@@ -139,15 +144,16 @@
OptimizationTarget segment_id1 =
OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
float mapping1[][2] = {{0.2, 1}, {0.5, 3}, {0.7, 4}};
- segment_database_->AddDiscreteMapping(segment_id1, mapping1, 3);
+ segment_database_->AddDiscreteMapping(segment_id1, mapping1, 3,
+ kAdaptiveToolbarSegmentationKey);
base::Time result_timestamp = base::Time::Now();
segment_database_->AddPredictionResult(segment_id1, 0.6, result_timestamp);
absl::optional<SelectedSegment> selected_segment;
- EXPECT_CALL(*prefs_, SaveSegmentationResultToPref(_))
+ EXPECT_CALL(*prefs_, SaveSegmentationResultToPref(_, _))
.Times(1)
- .WillOnce(SaveArg<0>(&selected_segment));
+ .WillOnce(SaveArg<1>(&selected_segment));
segment_selector_->OnModelExecutionCompleted(segment_id1);
ASSERT_TRUE(selected_segment.has_value());
@@ -157,12 +163,13 @@
OptimizationTarget segment_id2 =
OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE;
float mapping2[][2] = {{0.3, 1}, {0.4, 4}};
- segment_database_->AddDiscreteMapping(segment_id2, mapping2, 2);
+ segment_database_->AddDiscreteMapping(segment_id2, mapping2, 2,
+ kAdaptiveToolbarSegmentationKey);
segment_database_->AddPredictionResult(segment_id2, 0.5, result_timestamp);
- EXPECT_CALL(*prefs_, SaveSegmentationResultToPref(_))
+ EXPECT_CALL(*prefs_, SaveSegmentationResultToPref(_, _))
.Times(1)
- .WillOnce(SaveArg<0>(&selected_segment));
+ .WillOnce(SaveArg<1>(&selected_segment));
segment_selector_->OnModelExecutionCompleted(segment_id2);
ASSERT_TRUE(selected_segment.has_value());
@@ -175,7 +182,7 @@
OptimizationTarget segment_id0 =
OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE;
SelectedSegment from_history(segment_id0);
- EXPECT_CALL(*prefs_, ReadSegmentationResultFromPref())
+ EXPECT_CALL(*prefs_, ReadSegmentationResultFromPref(_))
.WillRepeatedly(Return(from_history));
base::RunLoop loop;
@@ -188,15 +195,16 @@
OptimizationTarget segment_id1 =
OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
float mapping1[][2] = {{0.2, 1}, {0.5, 3}, {0.7, 4}};
- segment_database_->AddDiscreteMapping(segment_id1, mapping1, 3);
+ segment_database_->AddDiscreteMapping(segment_id1, mapping1, 3,
+ kAdaptiveToolbarSegmentationKey);
base::Time result_timestamp = base::Time::Now();
segment_database_->AddPredictionResult(segment_id1, 0.6, result_timestamp);
absl::optional<SelectedSegment> selected_segment;
- EXPECT_CALL(*prefs_, SaveSegmentationResultToPref(_))
+ EXPECT_CALL(*prefs_, SaveSegmentationResultToPref(_, _))
.Times(1)
- .WillOnce(SaveArg<0>(&selected_segment));
+ .WillOnce(SaveArg<1>(&selected_segment));
segment_selector_->OnModelExecutionCompleted(segment_id1);
ASSERT_TRUE(selected_segment.has_value());
diff --git a/components/segmentation_platform/internal/selection/segmentation_result_prefs.cc b/components/segmentation_platform/internal/selection/segmentation_result_prefs.cc
index 330f815..0287165 100644
--- a/components/segmentation_platform/internal/selection/segmentation_result_prefs.cc
+++ b/components/segmentation_platform/internal/selection/segmentation_result_prefs.cc
@@ -4,9 +4,67 @@
#include "components/segmentation_platform/internal/selection/segmentation_result_prefs.h"
+#include "base/util/values/values_util.h"
+#include "components/prefs/pref_service.h"
+#include "components/prefs/scoped_user_pref_update.h"
+#include "components/segmentation_platform/internal/constants.h"
+#include "components/segmentation_platform/public/segmentation_platform_service.h"
+
namespace segmentation_platform {
SelectedSegment::SelectedSegment(OptimizationTarget segment_id)
: segment_id(segment_id), in_use(false) {}
+SelectedSegment::~SelectedSegment() = default;
+
+SegmentationResultPrefs::SegmentationResultPrefs(PrefService* pref_service)
+ : prefs_(pref_service) {}
+
+void SegmentationResultPrefs::SaveSegmentationResultToPref(
+ const std::string& result_key,
+ const absl::optional<SelectedSegment>& selected_segment) {
+ DictionaryPrefUpdate update(prefs_, kSegmentationResultPref);
+ base::DictionaryValue* dictionary = update.Get();
+ if (!selected_segment.has_value()) {
+ dictionary->RemoveKey(result_key);
+ return;
+ }
+
+ base::Value segmentation_result(base::Value::Type::DICTIONARY);
+ segmentation_result.SetIntKey("segment_id", selected_segment->segment_id);
+ segmentation_result.SetBoolKey("in_use", selected_segment->in_use);
+ segmentation_result.SetKey(
+ "selection_time", util::TimeToValue(selected_segment->selection_time));
+ dictionary->SetKey(result_key, std::move(segmentation_result));
+}
+
+absl::optional<SelectedSegment>
+SegmentationResultPrefs::ReadSegmentationResultFromPref(
+ const std::string& result_key) {
+ const base::DictionaryValue* dictionary =
+ prefs_->GetDictionary(kSegmentationResultPref);
+ DCHECK(dictionary);
+
+ const base::Value* value = dictionary->FindKey(result_key);
+ if (!value)
+ return absl::nullopt;
+
+ const base::DictionaryValue& segmentation_result =
+ base::Value::AsDictionaryValue(*value);
+
+ absl::optional<int> segment_id = segmentation_result.FindIntKey("segment_id");
+ absl::optional<bool> in_use = segmentation_result.FindBoolKey("in_use");
+ absl::optional<base::Time> selection_time =
+ util::ValueToTime(segmentation_result.FindPath("selection_time"));
+
+ SelectedSegment selected_segment(
+ static_cast<OptimizationTarget>(segment_id.value()));
+ if (in_use.has_value())
+ selected_segment.in_use = in_use.value();
+ if (selection_time.has_value())
+ selected_segment.selection_time = selection_time.value();
+
+ return selected_segment;
+}
+
} // namespace segmentation_platform
diff --git a/components/segmentation_platform/internal/selection/segmentation_result_prefs.h b/components/segmentation_platform/internal/selection/segmentation_result_prefs.h
index 93487a3..c82080e 100644
--- a/components/segmentation_platform/internal/selection/segmentation_result_prefs.h
+++ b/components/segmentation_platform/internal/selection/segmentation_result_prefs.h
@@ -11,17 +11,25 @@
using optimization_guide::proto::OptimizationTarget;
+class PrefService;
+
namespace segmentation_platform {
// Struct containing information about the selected segment. Convenient for
// reading and writing to prefs.
struct SelectedSegment {
public:
- OptimizationTarget segment_id;
- base::Time selection_time;
- bool in_use;
-
explicit SelectedSegment(OptimizationTarget segment_id);
+ ~SelectedSegment();
+
+ // The segment selection result.
+ OptimizationTarget segment_id;
+
+ // The time when the segment was selected.
+ base::Time selection_time;
+
+ // Whether or not the segment selection result is in use.
+ bool in_use;
};
// Stores the result of segmentation into prefs for faster lookup. The result
@@ -30,15 +38,26 @@
// selected segment has started to be used by clients.
class SegmentationResultPrefs {
public:
+ explicit SegmentationResultPrefs(PrefService* pref_service);
virtual ~SegmentationResultPrefs() = default;
+ // Disallow copy/assign.
+ SegmentationResultPrefs(const SegmentationResultPrefs& other) = delete;
+ SegmentationResultPrefs operator=(const SegmentationResultPrefs& other) =
+ delete;
+
// Writes the selected segment to prefs. Deletes the previous results if
// |selected_segment| is empty.
virtual void SaveSegmentationResultToPref(
- const absl::optional<SelectedSegment>& selected_segment) = 0;
+ const std::string& result_key,
+ const absl::optional<SelectedSegment>& selected_segment);
// Reads the selected segment from pref, if any.
- virtual absl::optional<SelectedSegment> ReadSegmentationResultFromPref() = 0;
+ virtual absl::optional<SelectedSegment> ReadSegmentationResultFromPref(
+ const std::string& result_key);
+
+ private:
+ PrefService* prefs_;
};
} // namespace segmentation_platform
diff --git a/components/segmentation_platform/internal/selection/segmentation_result_prefs_unittest.cc b/components/segmentation_platform/internal/selection/segmentation_result_prefs_unittest.cc
new file mode 100644
index 0000000..c105671
--- /dev/null
+++ b/components/segmentation_platform/internal/selection/segmentation_result_prefs_unittest.cc
@@ -0,0 +1,95 @@
+// Copyright 2021 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "components/segmentation_platform/internal/selection/segmentation_result_prefs.h"
+
+#include "base/run_loop.h"
+#include "base/test/task_environment.h"
+#include "components/prefs/pref_registry_simple.h"
+#include "components/prefs/testing_pref_service.h"
+#include "components/segmentation_platform/internal/constants.h"
+#include "components/segmentation_platform/public/segmentation_platform_service.h"
+#include "testing/gmock/include/gmock/gmock.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace segmentation_platform {
+
+class SegmentationResultPrefsTest : public testing::Test {
+ public:
+ SegmentationResultPrefsTest() = default;
+ ~SegmentationResultPrefsTest() override = default;
+
+ void SetUp() override {
+ result_prefs_ = std::make_unique<SegmentationResultPrefs>(&pref_service_);
+ pref_service_.registry()->RegisterDictionaryPref(kSegmentationResultPref);
+ }
+
+ protected:
+ base::test::TaskEnvironment task_environment_;
+ TestingPrefServiceSimple pref_service_;
+ std::unique_ptr<SegmentationResultPrefs> result_prefs_;
+};
+
+TEST_F(SegmentationResultPrefsTest, WriteResultAndRead) {
+ std::string result_key = "some_key";
+ // Start test with no result.
+ absl::optional<SelectedSegment> current_result =
+ result_prefs_->ReadSegmentationResultFromPref(result_key);
+ EXPECT_FALSE(current_result.has_value());
+
+ // Save a result. Verify by reading the result back.
+ OptimizationTarget segment_id =
+ OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
+ SelectedSegment selected_segment(segment_id);
+ result_prefs_->SaveSegmentationResultToPref(result_key, selected_segment);
+ current_result = result_prefs_->ReadSegmentationResultFromPref(result_key);
+ EXPECT_TRUE(current_result.has_value());
+ EXPECT_EQ(segment_id, current_result->segment_id);
+ EXPECT_FALSE(current_result->in_use);
+ EXPECT_EQ(base::Time(), current_result->selection_time);
+
+ // Overwrite the result with a new segment.
+ selected_segment.segment_id =
+ OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE;
+ selected_segment.in_use = true;
+ base::Time now = base::Time::Now();
+ selected_segment.selection_time = now;
+ result_prefs_->SaveSegmentationResultToPref(result_key, selected_segment);
+ current_result = result_prefs_->ReadSegmentationResultFromPref(result_key);
+ EXPECT_TRUE(current_result.has_value());
+ EXPECT_EQ(selected_segment.segment_id, current_result->segment_id);
+ EXPECT_TRUE(current_result->in_use);
+ EXPECT_EQ(now, current_result->selection_time);
+
+ // Write another result with a different key. This shouldn't overwrite the
+ // first key.
+ std::string result_key2 = "some_key2";
+ selected_segment.segment_id =
+ OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_VOICE;
+ result_prefs_->SaveSegmentationResultToPref(result_key2, selected_segment);
+ current_result = result_prefs_->ReadSegmentationResultFromPref(result_key2);
+ EXPECT_TRUE(current_result.has_value());
+ EXPECT_EQ(selected_segment.segment_id, current_result->segment_id);
+
+ current_result = result_prefs_->ReadSegmentationResultFromPref(result_key);
+ EXPECT_TRUE(current_result.has_value());
+ EXPECT_EQ(OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE,
+ current_result->segment_id);
+
+ // Save empty result. It should delete the current result.
+ result_prefs_->SaveSegmentationResultToPref(result_key, absl::nullopt);
+ current_result = result_prefs_->ReadSegmentationResultFromPref(result_key);
+ EXPECT_FALSE(current_result.has_value());
+}
+
+TEST_F(SegmentationResultPrefsTest, CorruptedValue) {
+ std::string result_key = "some_key";
+ SelectedSegment selected_segment(static_cast<OptimizationTarget>(100));
+ result_prefs_->SaveSegmentationResultToPref(result_key, selected_segment);
+ absl::optional<SelectedSegment> current_result =
+ result_prefs_->ReadSegmentationResultFromPref(result_key);
+ EXPECT_TRUE(current_result.has_value());
+ EXPECT_EQ(100, current_result->segment_id);
+}
+} // namespace segmentation_platform
diff --git a/components/segmentation_platform/public/segmentation_platform_service.h b/components/segmentation_platform/public/segmentation_platform_service.h
index ad6da290..b472cb5 100644
--- a/components/segmentation_platform/public/segmentation_platform_service.h
+++ b/components/segmentation_platform/public/segmentation_platform_service.h
@@ -7,6 +7,8 @@
#include "components/keyed_service/core/keyed_service.h"
+class PrefRegistrySimple;
+
namespace segmentation_platform {
// The core class of segmentation platform that integrates all the required
@@ -19,6 +21,10 @@
SegmentationPlatformService(const SegmentationPlatformService&) = delete;
SegmentationPlatformService& operator=(const SegmentationPlatformService&) =
delete;
+
+ // Registers preferences used by this class in the provided |registry|. This
+ // should be called for the Profile registry.
+ static void RegisterProfilePrefs(PrefRegistrySimple* registry);
};
} // namespace segmentation_platform