[segmentation] Replace OptimizationTarget with SegmentId

Updates the segmentation_platform internal code to use a new SegmentId
enum instead of OptimizationTarget. This lets us add segments that do
not need any smart logic to be the output of the selections.

SegmentId enum is added to the public API of segmentation_platform.
The existing enum values from OptimizationTarget are copied over,
with a warning to stop adding more enums until migration is done.

The public API of SegmentationPlatformService is unchanged and uses
OptimizationTarget, which is currently used by multiple clients.
The selector converts the SegmentId to OptimizationTarget to return
to the client. The public API will be updated in the next CL; this
CL does already update the Config to use SegmentId.

The optimization guide model wrapper converts the SegmentId to
OptimizationTarget to fetch the models.

All other changes are a mechanical replacement of the string
OptimizationTarget with SegmentId, namespace included. There should be
no behavior or metrics changes due to this CL. Additionally, this CL
makes style fixes to namespace and using declarations.

BUG=1315459

Change-Id: I5cd70b83bb6feef893ff74b8289d0542b4b342c4
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3656802
Commit-Queue: Siddhartha S <[email protected]>
Reviewed-by: Shakti Sahu <[email protected]>
Reviewed-by: David Trainor <[email protected]>
Reviewed-by: Min Qin <[email protected]>
Cr-Commit-Position: refs/heads/main@{#1006045}
diff --git a/components/segmentation_platform/internal/BUILD.gn b/components/segmentation_platform/internal/BUILD.gn
index baa208e4..4157fc05 100644
--- a/components/segmentation_platform/internal/BUILD.gn
+++ b/components/segmentation_platform/internal/BUILD.gn
@@ -99,6 +99,8 @@
     "scheduler/model_execution_scheduler.h",
     "scheduler/model_execution_scheduler_impl.cc",
     "scheduler/model_execution_scheduler_impl.h",
+    "segment_id_convertor.cc",
+    "segment_id_convertor.h",
     "segmentation_platform_service_impl.cc",
     "segmentation_platform_service_impl.h",
     "segmentation_ukm_helper.cc",
@@ -150,6 +152,7 @@
     "//components/prefs",
     "//components/segmentation_platform/internal/proto",
     "//components/segmentation_platform/public",
+    "//components/segmentation_platform/public/proto",
     "//components/ukm:ukm_recorder",
     "//services/metrics/public/cpp:metrics_cpp",
     "//services/metrics/public/cpp:ukm_builders",
diff --git a/components/segmentation_platform/internal/data_collection/training_data_collector_impl.cc b/components/segmentation_platform/internal/data_collection/training_data_collector_impl.cc
index 435f317..f63cf16 100644
--- a/components/segmentation_platform/internal/data_collection/training_data_collector_impl.cc
+++ b/components/segmentation_platform/internal/data_collection/training_data_collector_impl.cc
@@ -55,7 +55,7 @@
 
 // Find the segmentation key from the configs that contains the segment ID.
 std::string GetSegmentationKey(std::vector<std::unique_ptr<Config>>* configs,
-                               OptimizationTarget segment_id) {
+                               SegmentId segment_id) {
   if (!configs)
     return std::string();
 
@@ -107,7 +107,7 @@
   histogram_signal_handler_->AddObserver(this);
 
   DCHECK(segments);
-  const base::flat_set<OptimizationTarget>& allowed_ids =
+  const base::flat_set<SegmentId>& allowed_ids =
       SegmentationUkmHelper::GetInstance()->allowed_segment_ids();
   for (const auto& segment : *segments) {
     // Skip the segment if it is not in allowed list.
@@ -156,13 +156,12 @@
   // Report training data for all models that are interested in
   // |histogram_name| as output.
   if (it != immediate_collection_histograms_.end()) {
-    std::vector<OptimizationTarget> optimization_targets(it->second.begin(),
-                                                         it->second.end());
+    std::vector<SegmentId> segment_ids(it->second.begin(), it->second.end());
     auto param = absl::make_optional<ImmediaCollectionParam>();
     param->output_metric_hash = hash;
     param->output_value = static_cast<float>(sample);
     segment_info_database_->GetSegmentInfoForSegments(
-        optimization_targets,
+        segment_ids,
         base::BindOnce(&TrainingDataCollectorImpl::ReportForSegmentsInfoList,
                        weak_ptr_factory_.GetWeakPtr(), std::move(param)));
   }
@@ -336,8 +335,8 @@
   base::Time next_collection_time = GetNextReportTime(last_collection_time);
   if (clock_->Now() >= next_collection_time) {
     segment_info_database_->GetSegmentInfoForSegments(
-        std::vector<OptimizationTarget>(continuous_collection_segments_.begin(),
-                                        continuous_collection_segments_.end()),
+        std::vector<SegmentId>(continuous_collection_segments_.begin(),
+                               continuous_collection_segments_.end()),
         base::BindOnce(&TrainingDataCollectorImpl::ReportForSegmentsInfoList,
                        weak_ptr_factory_.GetWeakPtr(), absl::nullopt));
   }
diff --git a/components/segmentation_platform/internal/data_collection/training_data_collector_impl.h b/components/segmentation_platform/internal/data_collection/training_data_collector_impl.h
index b1bfc95..6a7000e3 100644
--- a/components/segmentation_platform/internal/data_collection/training_data_collector_impl.h
+++ b/components/segmentation_platform/internal/data_collection/training_data_collector_impl.h
@@ -13,16 +13,16 @@
 #include "base/memory/raw_ptr.h"
 #include "base/memory/weak_ptr.h"
 #include "base/metrics/histogram_base.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/data_collection/training_data_collector.h"
 #include "components/segmentation_platform/internal/database/segment_info_database.h"
 #include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
 #include "components/segmentation_platform/internal/signals/histogram_signal_handler.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 #include "third_party/abseil-cpp/absl/types/optional.h"
 
-using optimization_guide::proto::OptimizationTarget;
-
 namespace segmentation_platform {
+using proto::SegmentId;
+
 struct Config;
 class SegmentationResultPrefs;
 
@@ -88,12 +88,11 @@
   // Hash of histograms for immediate training data collection. When any
   // histogram hash contained in the map is recorded, a UKM message is reported
   // right away.
-  base::flat_map<uint64_t,
-                 base::flat_set<optimization_guide::proto::OptimizationTarget>>
+  base::flat_map<uint64_t, base::flat_set<proto::SegmentId>>
       immediate_collection_histograms_;
 
   // A list of segment IDs that needs to report metrics continuously.
-  std::set<OptimizationTarget> continuous_collection_segments_;
+  std::set<SegmentId> continuous_collection_segments_;
 
   base::WeakPtrFactory<TrainingDataCollectorImpl> weak_ptr_factory_{this};
 };
diff --git a/components/segmentation_platform/internal/data_collection/training_data_collector_impl_unittest.cc b/components/segmentation_platform/internal/data_collection/training_data_collector_impl_unittest.cc
index 1e8ace2ab..41d30d6 100644
--- a/components/segmentation_platform/internal/data_collection/training_data_collector_impl_unittest.cc
+++ b/components/segmentation_platform/internal/data_collection/training_data_collector_impl_unittest.cc
@@ -30,6 +30,9 @@
 #include "services/metrics/public/cpp/ukm_builders.h"
 #include "testing/gtest/include/gtest/gtest.h"
 
+namespace segmentation_platform {
+namespace {
+
 using ::base::test::RunOnceCallback;
 using ::testing::_;
 using ::testing::NiceMock;
@@ -38,18 +41,15 @@
     ::ukm::builders::Segmentation_ModelExecution;
 
 constexpr auto kTestOptimizationTarget0 =
-    OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
+    SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
 constexpr auto kTestOptimizationTarget1 =
-    OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE;
+    SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE;
 constexpr char kHistogramName0[] = "histogram0";
 constexpr char kHistogramName1[] = "histogram1";
 constexpr char kSegmentationKey[] = "test_key";
 constexpr int64_t kModelVersion = 123;
 constexpr int kSample = 1;
 
-namespace segmentation_platform {
-namespace {
-
 class TrainingDataCollectorImplTest : public ::testing::Test {
  public:
   TrainingDataCollectorImplTest() = default;
@@ -82,13 +82,13 @@
     configs_.emplace_back(std::make_unique<Config>());
     configs_[0]->segmentation_key = kSegmentationKey;
     configs_[0]->segment_ids.push_back(
-        OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB);
+        SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB);
     configs_[0]->segment_ids.push_back(
-        OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
+        SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
 
     SegmentationResultPrefs result_prefs(&prefs_);
     SelectedSegment selected_segment(
-        OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
+        SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
     selected_segment.selection_time = base::Time::Now() - base::Days(1);
     result_prefs.SaveSegmentationResultToPref(kSegmentationKey,
                                               selected_segment);
@@ -125,9 +125,8 @@
     return segment_info;
   }
 
-  proto::SegmentInfo* CreateSegment(OptimizationTarget optimization_target) {
-    auto* segment_info =
-        test_segment_db()->FindOrCreateSegment(optimization_target);
+  proto::SegmentInfo* CreateSegment(SegmentId segment_id) {
+    auto* segment_info = test_segment_db()->FindOrCreateSegment(segment_id);
     auto* model_metadata = segment_info->mutable_model_metadata();
     model_metadata->set_time_unit(proto::TimeUnit::DAY);
     model_metadata->set_signal_storage_length(7);
@@ -341,7 +340,7 @@
       {kTestOptimizationTarget0, kModelVersion,
        SegmentationUkmHelper::FloatToInt64(1.f),
        SegmentationUkmHelper::FloatToInt64(0.6f),
-       OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE,
+       SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE,
        base::Days(1).InSeconds(), SegmentationUkmHelper::FloatToInt64(2.f),
        SegmentationUkmHelper::FloatToInt64(3.f)});
 }
diff --git a/components/segmentation_platform/internal/database/database_maintenance_impl.cc b/components/segmentation_platform/internal/database/database_maintenance_impl.cc
index aa35a1de..52b2dfbe 100644
--- a/components/segmentation_platform/internal/database/database_maintenance_impl.cc
+++ b/components/segmentation_platform/internal/database/database_maintenance_impl.cc
@@ -90,7 +90,7 @@
 };
 
 DatabaseMaintenanceImpl::DatabaseMaintenanceImpl(
-    const base::flat_set<OptimizationTarget>& segment_ids,
+    const base::flat_set<SegmentId>& segment_ids,
     base::Clock* clock,
     SegmentInfoDatabase* segment_info_database,
     SignalDatabase* signal_database,
@@ -106,8 +106,7 @@
 DatabaseMaintenanceImpl::~DatabaseMaintenanceImpl() = default;
 
 void DatabaseMaintenanceImpl::ExecuteMaintenanceTasks() {
-  std::vector<OptimizationTarget> segment_ids(segment_ids_.begin(),
-                                              segment_ids_.end());
+  std::vector<SegmentId> segment_ids(segment_ids_.begin(), segment_ids_.end());
   default_model_manager_->GetAllSegmentInfoFromBothModels(
       segment_ids, segment_info_database_,
       base::BindOnce(&DatabaseMaintenanceImpl::OnSegmentInfoCallback,
diff --git a/components/segmentation_platform/internal/database/database_maintenance_impl.h b/components/segmentation_platform/internal/database/database_maintenance_impl.h
index bb6d2f77..d8d67db5a 100644
--- a/components/segmentation_platform/internal/database/database_maintenance_impl.h
+++ b/components/segmentation_platform/internal/database/database_maintenance_impl.h
@@ -15,19 +15,19 @@
 #include "base/memory/raw_ptr.h"
 #include "base/memory/weak_ptr.h"
 #include "base/time/time.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/database/database_maintenance.h"
 #include "components/segmentation_platform/internal/execution/default_model_manager.h"
 #include "components/segmentation_platform/internal/proto/types.pb.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 
 namespace base {
 class Clock;
 class Time;
 }  // namespace base
 
-using optimization_guide::proto::OptimizationTarget;
-
 namespace segmentation_platform {
+using proto::SegmentId;
+
 class DefaultModelManager;
 class SegmentInfoDatabase;
 class SignalDatabase;
@@ -40,13 +40,12 @@
   using SignalIdentifier = std::pair<uint64_t, proto::SignalType>;
   using CleanupItem = std::tuple<uint64_t, proto::SignalType, base::Time>;
 
-  explicit DatabaseMaintenanceImpl(
-      const base::flat_set<OptimizationTarget>& segment_ids,
-      base::Clock* clock,
-      SegmentInfoDatabase* segment_info_database,
-      SignalDatabase* signal_database,
-      SignalStorageConfig* signal_storage_config,
-      DefaultModelManager* default_model_manager);
+  explicit DatabaseMaintenanceImpl(const base::flat_set<SegmentId>& segment_ids,
+                                   base::Clock* clock,
+                                   SegmentInfoDatabase* segment_info_database,
+                                   SignalDatabase* signal_database,
+                                   SignalStorageConfig* signal_storage_config,
+                                   DefaultModelManager* default_model_manager);
   ~DatabaseMaintenanceImpl() override;
 
   // DatabaseMaintenance overrides.
@@ -88,7 +87,7 @@
   void CompactSamplesDone(base::OnceClosure next_action);
 
   // Input.
-  base::flat_set<OptimizationTarget> segment_ids_;
+  base::flat_set<SegmentId> segment_ids_;
   raw_ptr<base::Clock> clock_;
 
   // Databases.
diff --git a/components/segmentation_platform/internal/database/database_maintenance_impl_unittest.cc b/components/segmentation_platform/internal/database/database_maintenance_impl_unittest.cc
index 8c7bc74..ce73d34 100644
--- a/components/segmentation_platform/internal/database/database_maintenance_impl_unittest.cc
+++ b/components/segmentation_platform/internal/database/database_maintenance_impl_unittest.cc
@@ -15,7 +15,6 @@
 #include "base/test/task_environment.h"
 #include "base/threading/thread_task_runner_handle.h"
 #include "base/time/time.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/database/mock_signal_database.h"
 #include "components/segmentation_platform/internal/database/mock_signal_storage_config.h"
 #include "components/segmentation_platform/internal/database/signal_storage_config.h"
@@ -24,6 +23,7 @@
 #include "components/segmentation_platform/internal/proto/aggregation.pb.h"
 #include "components/segmentation_platform/internal/proto/types.pb.h"
 #include "components/segmentation_platform/public/config.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 #include "testing/gmock/include/gmock/gmock.h"
 #include "testing/gtest/include/gtest/gtest.h"
 
@@ -31,9 +31,9 @@
 using ::testing::_;
 using ::testing::SetArgReferee;
 
-using optimization_guide::proto::OptimizationTarget;
-
 namespace segmentation_platform {
+
+using SegmentId = proto::SegmentId;
 using SignalType = proto::SignalType;
 using Aggregation = proto::Aggregation;
 using SignalIdentifier = std::pair<uint64_t, SignalType>;
@@ -46,7 +46,7 @@
 std::string kTestSegmentationKey = "some_key";
 
 struct SignalData {
-  OptimizationTarget target;
+  SegmentId target;
   proto::SignalType signal_type;
   std::string name;
   uint64_t name_hash;
@@ -64,11 +64,11 @@
 class TestDefaultModelManager : public DefaultModelManager {
  public:
   TestDefaultModelManager()
-      : DefaultModelManager(nullptr, std::vector<OptimizationTarget>()) {}
+      : DefaultModelManager(nullptr, std::vector<SegmentId>()) {}
   ~TestDefaultModelManager() override = default;
 
   void GetAllSegmentInfoFromDefaultModel(
-      const std::vector<OptimizationTarget>& segment_ids,
+      const std::vector<SegmentId>& segment_ids,
       MultipleSegmentInfoCallback callback) override {
     base::ThreadTaskRunnerHandle::Get()->PostTask(
         FROM_HERE, base::BindOnce(std::move(callback),
@@ -76,7 +76,7 @@
   }
 
   void GetAllSegmentInfoFromBothModels(
-      const std::vector<OptimizationTarget>& segment_ids,
+      const std::vector<SegmentId>& segment_ids,
       SegmentInfoDatabase* segment_database,
       MultipleSegmentInfoCallback callback) override {
     segment_database->GetSegmentInfoForSegments(
@@ -107,9 +107,9 @@
     segment_info_database_ = std::make_unique<test::TestSegmentInfoDatabase>();
     signal_database_ = std::make_unique<MockSignalDatabase>();
     signal_storage_config_ = std::make_unique<MockSignalStorageConfig>();
-    base::flat_set<OptimizationTarget> segment_ids = {
-        OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
-        OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE};
+    base::flat_set<SegmentId> segment_ids = {
+        SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
+        SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE};
     default_model_manager_ = std::make_unique<TestDefaultModelManager>();
     database_maintenance_ = std::make_unique<DatabaseMaintenanceImpl>(
         segment_ids, &clock_, segment_info_database_.get(),
@@ -181,13 +181,13 @@
 
 TEST_F(DatabaseMaintenanceImplTest, ExecuteMaintenanceTasks) {
   std::vector<SignalData> signal_datas = {
-      {OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
+      {SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
        SignalType::HISTOGRAM_VALUE, "Foo", base::HashMetricName("Foo"), 44, 1,
        Aggregation::COUNT, clock_.Now() - base::Days(10), true},
-      {OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE,
+      {SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE,
        SignalType::HISTOGRAM_ENUM, "Bar", base::HashMetricName("Bar"), 33, 1,
        Aggregation::COUNT, clock_.Now() - base::Days(5), true},
-      {OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE,
+      {SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE,
        SignalType::USER_ACTION, "Failed", base::HashMetricName("Failed"), 22, 1,
        Aggregation::COUNT, clock_.Now() - base::Days(1), false},
   };
diff --git a/components/segmentation_platform/internal/database/segment_info_database.cc b/components/segmentation_platform/internal/database/segment_info_database.cc
index 12a3f46..3cad35a 100644
--- a/components/segmentation_platform/internal/database/segment_info_database.cc
+++ b/components/segmentation_platform/internal/database/segment_info_database.cc
@@ -12,7 +12,7 @@
 
 namespace {
 
-std::string ToString(OptimizationTarget segment_id) {
+std::string ToString(SegmentId segment_id) {
   return base::NumberToString(static_cast<int>(segment_id));
 }
 
@@ -53,10 +53,10 @@
 }
 
 void SegmentInfoDatabase::GetSegmentInfoForSegments(
-    const std::vector<OptimizationTarget>& segment_ids,
+    const std::vector<SegmentId>& segment_ids,
     MultipleSegmentInfoCallback callback) {
   std::vector<std::string> keys;
-  for (OptimizationTarget target : segment_ids)
+  for (SegmentId target : segment_ids)
     keys.emplace_back(ToString(target));
 
   database_->LoadEntriesWithFilter(
@@ -69,7 +69,7 @@
                      weak_ptr_factory_.GetWeakPtr(), std::move(callback)));
 }
 
-void SegmentInfoDatabase::GetSegmentInfo(OptimizationTarget segment_id,
+void SegmentInfoDatabase::GetSegmentInfo(SegmentId segment_id,
                                          SegmentInfoCallback callback) {
   database_->GetEntry(
       ToString(segment_id),
@@ -86,7 +86,7 @@
 }
 
 void SegmentInfoDatabase::UpdateSegment(
-    OptimizationTarget segment_id,
+    SegmentId segment_id,
     absl::optional<proto::SegmentInfo> segment_info,
     SuccessCallback callback) {
   auto entries_to_save = std::make_unique<
@@ -104,7 +104,7 @@
 }
 
 void SegmentInfoDatabase::SaveSegmentResult(
-    OptimizationTarget segment_id,
+    SegmentId segment_id,
     absl::optional<proto::PredictionResult> result,
     SuccessCallback callback) {
   GetSegmentInfo(
diff --git a/components/segmentation_platform/internal/database/segment_info_database.h b/components/segmentation_platform/internal/database/segment_info_database.h
index 2c4ab12..41dc38a 100644
--- a/components/segmentation_platform/internal/database/segment_info_database.h
+++ b/components/segmentation_platform/internal/database/segment_info_database.h
@@ -10,15 +10,15 @@
 
 #include "base/callback.h"
 #include "components/leveldb_proto/public/proto_database.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/proto/model_metadata.pb.h"
 #include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 #include "third_party/abseil-cpp/absl/types/optional.h"
 
-using optimization_guide::proto::OptimizationTarget;
-
 namespace segmentation_platform {
 
+using proto::SegmentId;
+
 namespace proto {
 class SegmentInfo;
 class PredictionResult;
@@ -29,8 +29,7 @@
 class SegmentInfoDatabase {
  public:
   using SuccessCallback = base::OnceCallback<void(bool)>;
-  using SegmentInfoList =
-      std::vector<std::pair<OptimizationTarget, proto::SegmentInfo>>;
+  using SegmentInfoList = std::vector<std::pair<SegmentId, proto::SegmentInfo>>;
   using MultipleSegmentInfoCallback =
       base::OnceCallback<void(std::unique_ptr<SegmentInfoList>)>;
   using SegmentInfoCallback =
@@ -52,24 +51,24 @@
 
   // Called to get metadata for a given list of segments.
   virtual void GetSegmentInfoForSegments(
-      const std::vector<OptimizationTarget>& segment_ids,
+      const std::vector<SegmentId>& segment_ids,
       MultipleSegmentInfoCallback callback);
 
   // Called to get the metadata for a given segment.
-  virtual void GetSegmentInfo(OptimizationTarget segment_id,
+  virtual void GetSegmentInfo(SegmentId segment_id,
                               SegmentInfoCallback callback);
 
   // Called to save or update metadata for a segment. The previous data is
   // overwritten. If |segment_info| is empty, the segment will be deleted.
   // TODO(shaktisahu): How does the client know if a segment is to be deleted?
-  virtual void UpdateSegment(OptimizationTarget segment_id,
+  virtual void UpdateSegment(SegmentId segment_id,
                              absl::optional<proto::SegmentInfo> segment_info,
                              SuccessCallback callback);
 
   // Called to write the model execution results for a given segment. It will
   // first read the currently stored result, and then overwrite it with
   // |result|. If |result| is null, the existing result will be deleted.
-  virtual void SaveSegmentResult(OptimizationTarget segment_id,
+  virtual void SaveSegmentResult(SegmentId segment_id,
                                  absl::optional<proto::PredictionResult> result,
                                  SuccessCallback callback);
 
diff --git a/components/segmentation_platform/internal/database/segment_info_database_unittest.cc b/components/segmentation_platform/internal/database/segment_info_database_unittest.cc
index 9d492345..1dbe9e16 100644
--- a/components/segmentation_platform/internal/database/segment_info_database_unittest.cc
+++ b/components/segmentation_platform/internal/database/segment_info_database_unittest.cc
@@ -18,16 +18,15 @@
 namespace {
 
 // Test Ids.
-const OptimizationTarget kSegmentId =
-    OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
-const OptimizationTarget kSegmentId2 =
-    OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE;
+const SegmentId kSegmentId =
+    SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
+const SegmentId kSegmentId2 = SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE;
 
-std::string ToString(OptimizationTarget segment_id) {
+std::string ToString(SegmentId segment_id) {
   return base::NumberToString(static_cast<int>(segment_id));
 }
 
-proto::SegmentInfo CreateSegment(OptimizationTarget segment_id,
+proto::SegmentInfo CreateSegment(SegmentId segment_id,
                                  absl::optional<int> result = absl::nullopt) {
   proto::SegmentInfo info;
   info.set_segment_id(segment_id);
@@ -71,14 +70,13 @@
     segment_db_.reset();
   }
 
-  void VerifyDb(std::vector<OptimizationTarget> expected_ids) {
+  void VerifyDb(std::vector<SegmentId> expected_ids) {
     EXPECT_EQ(expected_ids.size(), db_entries_.size());
     for (auto segment_id : expected_ids)
       EXPECT_TRUE(db_entries_.find(ToString(segment_id)) != db_entries_.end());
   }
 
-  void WriteResult(OptimizationTarget segment_id,
-                   absl::optional<float> result) {
+  void WriteResult(SegmentId segment_id, absl::optional<float> result) {
     proto::PredictionResult prediction_result;
     if (result.has_value())
       prediction_result.set_result(result.value());
@@ -92,8 +90,7 @@
     db_->UpdateCallback(true);
   }
 
-  void VerifyResult(OptimizationTarget segment_id,
-                    absl::optional<float> result) {
+  void VerifyResult(SegmentId segment_id, absl::optional<float> result) {
     segment_db_->GetSegmentInfo(
         segment_id, base::BindOnce(&SegmentInfoDatabaseTest::OnGetSegment,
                                    base::Unretained(this)));
diff --git a/components/segmentation_platform/internal/database/storage_service.cc b/components/segmentation_platform/internal/database/storage_service.cc
index e4618f8..90052cfa 100644
--- a/components/segmentation_platform/internal/database/storage_service.cc
+++ b/components/segmentation_platform/internal/database/storage_service.cc
@@ -31,8 +31,7 @@
     scoped_refptr<base::SequencedTaskRunner> task_runner,
     base::Clock* clock,
     UkmDataManager* ukm_data_manager,
-    base::flat_set<optimization_guide::proto::OptimizationTarget>
-        all_segment_ids,
+    base::flat_set<proto::SegmentId> all_segment_ids,
     ModelProviderFactory* model_provider_factory)
     : StorageService(
           db_provider->GetDB<proto::SegmentInfo>(
@@ -60,13 +59,12 @@
         signal_storage_config_db,
     base::Clock* clock,
     UkmDataManager* ukm_data_manager,
-    base::flat_set<optimization_guide::proto::OptimizationTarget>
-        all_segment_ids,
+    base::flat_set<proto::SegmentId> all_segment_ids,
     ModelProviderFactory* model_provider_factory)
     : default_model_manager_(std::make_unique<DefaultModelManager>(
           model_provider_factory,
-          std::vector<OptimizationTarget>(all_segment_ids.begin(),
-                                          all_segment_ids.end()))),
+          std::vector<SegmentId>(all_segment_ids.begin(),
+                                 all_segment_ids.end()))),
       segment_info_database_(
           std::make_unique<SegmentInfoDatabase>(std::move(segment_db))),
       signal_database_(
diff --git a/components/segmentation_platform/internal/database/storage_service.h b/components/segmentation_platform/internal/database/storage_service.h
index d5dbd1a1..258bc0dc 100644
--- a/components/segmentation_platform/internal/database/storage_service.h
+++ b/components/segmentation_platform/internal/database/storage_service.h
@@ -12,7 +12,7 @@
 #include "base/containers/flat_set.h"
 #include "base/memory/raw_ptr.h"
 #include "components/leveldb_proto/public/proto_database.h"
-#include "components/optimization_guide/proto/models.pb.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 #include "third_party/abseil-cpp/absl/types/optional.h"
 
 namespace base {
@@ -65,8 +65,7 @@
                  scoped_refptr<base::SequencedTaskRunner> task_runner,
                  base::Clock* clock,
                  UkmDataManager* ukm_data_manager,
-                 base::flat_set<optimization_guide::proto::OptimizationTarget>
-                     all_segment_ids,
+                 base::flat_set<proto::SegmentId> all_segment_ids,
                  ModelProviderFactory* model_provider_factory);
 
   // For tests:
@@ -79,8 +78,7 @@
           signal_storage_config_db,
       base::Clock* clock,
       UkmDataManager* ukm_data_manager,
-      base::flat_set<optimization_guide::proto::OptimizationTarget>
-          all_segment_ids,
+      base::flat_set<proto::SegmentId> all_segment_ids,
       ModelProviderFactory* model_provider_factory);
 
   // For tests:
diff --git a/components/segmentation_platform/internal/database/test_segment_info_database.cc b/components/segmentation_platform/internal/database/test_segment_info_database.cc
index 3fbd7c3..1457adb 100644
--- a/components/segmentation_platform/internal/database/test_segment_info_database.cc
+++ b/components/segmentation_platform/internal/database/test_segment_info_database.cc
@@ -8,11 +8,11 @@
 
 #include "base/containers/contains.h"
 #include "base/metrics/metrics_hashes.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/constants.h"
 #include "components/segmentation_platform/internal/metadata/metadata_writer.h"
 #include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
 #include "components/segmentation_platform/internal/proto/types.pb.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 #include "third_party/abseil-cpp/absl/types/optional.h"
 
 namespace segmentation_platform::test {
@@ -33,7 +33,7 @@
 }
 
 void TestSegmentInfoDatabase::GetSegmentInfoForSegments(
-    const std::vector<OptimizationTarget>& segment_ids,
+    const std::vector<SegmentId>& segment_ids,
     MultipleSegmentInfoCallback callback) {
   auto result = std::make_unique<SegmentInfoDatabase::SegmentInfoList>();
   for (const auto& pair : segment_infos_) {
@@ -43,13 +43,13 @@
   std::move(callback).Run(std::move(result));
 }
 
-void TestSegmentInfoDatabase::GetSegmentInfo(OptimizationTarget segment_id,
+void TestSegmentInfoDatabase::GetSegmentInfo(SegmentId segment_id,
                                              SegmentInfoCallback callback) {
-  auto result = std::find_if(
-      segment_infos_.begin(), segment_infos_.end(),
-      [segment_id](std::pair<OptimizationTarget, proto::SegmentInfo> pair) {
-        return pair.first == segment_id;
-      });
+  auto result =
+      std::find_if(segment_infos_.begin(), segment_infos_.end(),
+                   [segment_id](std::pair<SegmentId, proto::SegmentInfo> pair) {
+                     return pair.first == segment_id;
+                   });
 
   std::move(callback).Run(result == segment_infos_.end()
                               ? absl::nullopt
@@ -57,7 +57,7 @@
 }
 
 void TestSegmentInfoDatabase::UpdateSegment(
-    OptimizationTarget segment_id,
+    SegmentId segment_id,
     absl::optional<proto::SegmentInfo> segment_info,
     SuccessCallback callback) {
   if (segment_info.has_value()) {
@@ -67,8 +67,7 @@
     // Delete the segment.
     auto new_end = std::remove_if(
         segment_infos_.begin(), segment_infos_.end(),
-        [segment_id](
-            const std::pair<OptimizationTarget, proto::SegmentInfo>& pair) {
+        [segment_id](const std::pair<SegmentId, proto::SegmentInfo>& pair) {
           return pair.first == segment_id;
         });
     segment_infos_.erase(new_end, segment_infos_.end());
@@ -77,7 +76,7 @@
 }
 
 void TestSegmentInfoDatabase::SaveSegmentResult(
-    OptimizationTarget segment_id,
+    SegmentId segment_id,
     absl::optional<proto::PredictionResult> result,
     SuccessCallback callback) {
   proto::SegmentInfo* info = FindOrCreateSegment(segment_id);
@@ -92,7 +91,7 @@
 }
 
 void TestSegmentInfoDatabase::AddUserActionFeature(
-    OptimizationTarget segment_id,
+    SegmentId segment_id,
     const std::string& name,
     uint64_t bucket_count,
     uint64_t tensor_length,
@@ -111,7 +110,7 @@
 }
 
 void TestSegmentInfoDatabase::AddHistogramValueFeature(
-    OptimizationTarget segment_id,
+    SegmentId segment_id,
     const std::string& name,
     uint64_t bucket_count,
     uint64_t tensor_length,
@@ -130,7 +129,7 @@
 }
 
 void TestSegmentInfoDatabase::AddHistogramEnumFeature(
-    OptimizationTarget segment_id,
+    SegmentId segment_id,
     const std::string& name,
     uint64_t bucket_count,
     uint64_t tensor_length,
@@ -151,7 +150,7 @@
 }
 
 void TestSegmentInfoDatabase::AddSqlFeature(
-    OptimizationTarget segment_id,
+    SegmentId segment_id,
     const MetadataWriter::SqlFeature& feature) {
   proto::SegmentInfo* info = FindOrCreateSegment(segment_id);
   MetadataWriter writer(info->mutable_model_metadata());
@@ -159,7 +158,7 @@
   writer.AddSqlFeatures(features, 1);
 }
 
-void TestSegmentInfoDatabase::AddPredictionResult(OptimizationTarget segment_id,
+void TestSegmentInfoDatabase::AddPredictionResult(SegmentId segment_id,
                                                   float score,
                                                   base::Time timestamp) {
   proto::SegmentInfo* info = FindOrCreateSegment(segment_id);
@@ -170,7 +169,7 @@
 }
 
 void TestSegmentInfoDatabase::AddDiscreteMapping(
-    OptimizationTarget segment_id,
+    SegmentId segment_id,
     const float mappings[][2],
     int num_pairs,
     const std::string& discrete_mapping_key) {
@@ -186,7 +185,7 @@
   }
 }
 
-void TestSegmentInfoDatabase::SetBucketDuration(OptimizationTarget segment_id,
+void TestSegmentInfoDatabase::SetBucketDuration(SegmentId segment_id,
                                                 uint64_t bucket_duration,
                                                 proto::TimeUnit time_unit) {
   proto::SegmentInfo* info = FindOrCreateSegment(segment_id);
@@ -195,7 +194,7 @@
 }
 
 proto::SegmentInfo* TestSegmentInfoDatabase::FindOrCreateSegment(
-    OptimizationTarget segment_id) {
+    SegmentId segment_id) {
   proto::SegmentInfo* info = nullptr;
   for (auto& pair : segment_infos_) {
     if (pair.first == segment_id) {
diff --git a/components/segmentation_platform/internal/database/test_segment_info_database.h b/components/segmentation_platform/internal/database/test_segment_info_database.h
index b50d9c41..3b6d388 100644
--- a/components/segmentation_platform/internal/database/test_segment_info_database.h
+++ b/components/segmentation_platform/internal/database/test_segment_info_database.h
@@ -27,53 +27,52 @@
   // SegmentInfoDatabase overrides.
   void Initialize(SuccessCallback callback) override;
   void GetAllSegmentInfo(MultipleSegmentInfoCallback callback) override;
-  void GetSegmentInfoForSegments(
-      const std::vector<OptimizationTarget>& segment_ids,
-      MultipleSegmentInfoCallback callback) override;
-  void GetSegmentInfo(OptimizationTarget segment_id,
+  void GetSegmentInfoForSegments(const std::vector<SegmentId>& segment_ids,
+                                 MultipleSegmentInfoCallback callback) override;
+  void GetSegmentInfo(SegmentId segment_id,
                       SegmentInfoCallback callback) override;
-  void UpdateSegment(OptimizationTarget segment_id,
+  void UpdateSegment(SegmentId segment_id,
                      absl::optional<proto::SegmentInfo> segment_info,
                      SuccessCallback callback) override;
-  void SaveSegmentResult(OptimizationTarget segment_id,
+  void SaveSegmentResult(SegmentId segment_id,
                          absl::optional<proto::PredictionResult> result,
                          SuccessCallback callback) override;
 
   // Test helper methods.
-  void AddUserActionFeature(OptimizationTarget segment_id,
+  void AddUserActionFeature(SegmentId segment_id,
                             const std::string& user_action,
                             uint64_t bucket_count,
                             uint64_t tensor_length,
                             proto::Aggregation aggregation);
-  void AddHistogramValueFeature(OptimizationTarget segment_id,
+  void AddHistogramValueFeature(SegmentId segment_id,
                                 const std::string& histogram,
                                 uint64_t bucket_count,
                                 uint64_t tensor_length,
                                 proto::Aggregation aggregation);
-  void AddHistogramEnumFeature(OptimizationTarget segment_id,
+  void AddHistogramEnumFeature(SegmentId segment_id,
                                const std::string& histogram_name,
                                uint64_t bucket_count,
                                uint64_t tensor_length,
                                proto::Aggregation aggregation,
                                const std::vector<int32_t>& accepted_enum_ids);
-  void AddSqlFeature(OptimizationTarget segment_id,
+  void AddSqlFeature(SegmentId segment_id,
                      const MetadataWriter::SqlFeature& feature);
-  void AddPredictionResult(OptimizationTarget segment_id,
+  void AddPredictionResult(SegmentId segment_id,
                            float score,
                            base::Time timestamp);
-  void AddDiscreteMapping(OptimizationTarget segment_id,
+  void AddDiscreteMapping(SegmentId segment_id,
                           const float mappings[][2],
                           int num_pairs,
                           const std::string& discrete_mapping_key);
-  void SetBucketDuration(OptimizationTarget segment_id,
+  void SetBucketDuration(SegmentId segment_id,
                          uint64_t bucket_duration,
                          proto::TimeUnit time_unit);
 
   // Finds a segment with given |segment_id|. Creates one if it doesn't exist.
-  proto::SegmentInfo* FindOrCreateSegment(OptimizationTarget segment_id);
+  proto::SegmentInfo* FindOrCreateSegment(SegmentId segment_id);
 
  private:
-  std::vector<std::pair<OptimizationTarget, proto::SegmentInfo>> segment_infos_;
+  std::vector<std::pair<SegmentId, proto::SegmentInfo>> segment_infos_;
 };
 
 }  // namespace segmentation_platform::test
diff --git a/components/segmentation_platform/internal/execution/default_model_manager.cc b/components/segmentation_platform/internal/execution/default_model_manager.cc
index c12c96bd..460668c 100644
--- a/components/segmentation_platform/internal/execution/default_model_manager.cc
+++ b/components/segmentation_platform/internal/execution/default_model_manager.cc
@@ -14,9 +14,9 @@
 
 DefaultModelManager::DefaultModelManager(
     ModelProviderFactory* model_provider_factory,
-    const std::vector<OptimizationTarget>& segment_ids)
+    const std::vector<SegmentId>& segment_ids)
     : model_provider_factory_(model_provider_factory) {
-  for (OptimizationTarget segment_id : segment_ids) {
+  for (SegmentId segment_id : segment_ids) {
     std::unique_ptr<ModelProvider> provider =
         model_provider_factory->CreateDefaultProvider(segment_id);
     if (!provider)
@@ -28,8 +28,7 @@
 
 DefaultModelManager::~DefaultModelManager() = default;
 
-ModelProvider* DefaultModelManager::GetDefaultProvider(
-    OptimizationTarget segment_id) {
+ModelProvider* DefaultModelManager::GetDefaultProvider(SegmentId segment_id) {
   auto it = default_model_providers_.find(segment_id);
   if (it != default_model_providers_.end())
     return it->second.get();
@@ -37,21 +36,20 @@
 }
 
 void DefaultModelManager::GetAllSegmentInfoFromDefaultModel(
-    const std::vector<OptimizationTarget>& segment_ids,
+    const std::vector<SegmentId>& segment_ids,
     MultipleSegmentInfoCallback callback) {
   auto result = std::make_unique<SegmentInfoList>();
-  std::deque<OptimizationTarget> remaining_segment_ids(segment_ids.begin(),
-                                                       segment_ids.end());
+  std::deque<SegmentId> remaining_segment_ids(segment_ids.begin(),
+                                              segment_ids.end());
   GetNextSegmentInfoFromDefaultModel(
       std::move(result), std::move(remaining_segment_ids), std::move(callback));
 }
 
 void DefaultModelManager::GetNextSegmentInfoFromDefaultModel(
     std::unique_ptr<SegmentInfoList> result,
-    std::deque<OptimizationTarget> remaining_segment_ids,
+    std::deque<SegmentId> remaining_segment_ids,
     MultipleSegmentInfoCallback callback) {
-  OptimizationTarget segment_id =
-      OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN;
+  SegmentId segment_id = SegmentId::OPTIMIZATION_TARGET_UNKNOWN;
   ModelProvider* default_provider = nullptr;
 
   // Find the next available default provider.
@@ -78,9 +76,9 @@
 
 void DefaultModelManager::OnFetchDefaultModel(
     std::unique_ptr<SegmentInfoList> result,
-    std::deque<OptimizationTarget> remaining_segment_ids,
+    std::deque<SegmentId> remaining_segment_ids,
     MultipleSegmentInfoCallback callback,
-    OptimizationTarget segment_id,
+    SegmentId segment_id,
     proto::SegmentationModelMetadata metadata,
     int64_t model_version) {
   auto info = std::make_unique<SegmentInfoWrapper>();
@@ -95,7 +93,7 @@
 }
 
 void DefaultModelManager::GetAllSegmentInfoFromBothModels(
-    const std::vector<OptimizationTarget>& segment_ids,
+    const std::vector<SegmentId>& segment_ids,
     SegmentInfoDatabase* segment_database,
     MultipleSegmentInfoCallback callback) {
   segment_database->GetSegmentInfoForSegments(
@@ -106,7 +104,7 @@
 }
 
 void DefaultModelManager::OnGetAllSegmentInfoFromDatabase(
-    const std::vector<OptimizationTarget>& segment_ids,
+    const std::vector<SegmentId>& segment_ids,
     MultipleSegmentInfoCallback callback,
     std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> segment_infos) {
   GetAllSegmentInfoFromDefaultModel(
@@ -137,7 +135,7 @@
 }
 
 void DefaultModelManager::SetDefaultProvidersForTesting(
-    std::map<OptimizationTarget, std::unique_ptr<ModelProvider>>&& providers) {
+    std::map<SegmentId, std::unique_ptr<ModelProvider>>&& providers) {
   default_model_providers_ = std::move(providers);
 }
 
diff --git a/components/segmentation_platform/internal/execution/default_model_manager.h b/components/segmentation_platform/internal/execution/default_model_manager.h
index 911a31cc..e0ca96b 100644
--- a/components/segmentation_platform/internal/execution/default_model_manager.h
+++ b/components/segmentation_platform/internal/execution/default_model_manager.h
@@ -20,9 +20,9 @@
 #include "components/segmentation_platform/public/model_provider.h"
 #include "third_party/abseil-cpp/absl/types/optional.h"
 
-using optimization_guide::proto::OptimizationTarget;
-
 namespace segmentation_platform {
+using proto::SegmentId;
+
 class SegmentInfoDatabase;
 
 // DefaultModelManager provides support to query all default models available.
@@ -31,7 +31,7 @@
 class DefaultModelManager {
  public:
   DefaultModelManager(ModelProviderFactory* model_provider_factory,
-                      const std::vector<OptimizationTarget>& segment_ids);
+                      const std::vector<SegmentId>& segment_ids);
   virtual ~DefaultModelManager();
 
   // Disallow copy/assign.
@@ -60,37 +60,37 @@
   // default model for a given set of segment IDs. The result can contain
   // the same segment ID multiple times.
   virtual void GetAllSegmentInfoFromBothModels(
-      const std::vector<OptimizationTarget>& segment_ids,
+      const std::vector<SegmentId>& segment_ids,
       SegmentInfoDatabase* segment_database,
       MultipleSegmentInfoCallback callback);
 
   // Called to get the segment info from the default model for a given set of
   // segment IDs.
   virtual void GetAllSegmentInfoFromDefaultModel(
-      const std::vector<OptimizationTarget>& segment_ids,
+      const std::vector<SegmentId>& segment_ids,
       MultipleSegmentInfoCallback callback);
 
-  // Returns the default provider or `nulllptr` when unavailable.
+  // Returns the default provider or `nullptr` when unavailable.
-  ModelProvider* GetDefaultProvider(OptimizationTarget segment_id);
+  ModelProvider* GetDefaultProvider(SegmentId segment_id);
 
   void SetDefaultProvidersForTesting(
-      std::map<OptimizationTarget, std::unique_ptr<ModelProvider>>&& providers);
+      std::map<SegmentId, std::unique_ptr<ModelProvider>>&& providers);
 
  private:
   void GetNextSegmentInfoFromDefaultModel(
       std::unique_ptr<SegmentInfoList> result,
-      std::deque<OptimizationTarget> remaining_segment_ids,
+      std::deque<SegmentId> remaining_segment_ids,
       MultipleSegmentInfoCallback callback);
 
   void OnFetchDefaultModel(std::unique_ptr<SegmentInfoList> result,
-                           std::deque<OptimizationTarget> remaining_segment_ids,
+                           std::deque<SegmentId> remaining_segment_ids,
                            MultipleSegmentInfoCallback callback,
-                           OptimizationTarget segment_id,
+                           SegmentId segment_id,
                            proto::SegmentationModelMetadata metadata,
                            int64_t model_version);
 
   void OnGetAllSegmentInfoFromDatabase(
-      const std::vector<OptimizationTarget>& segment_ids,
+      const std::vector<SegmentId>& segment_ids,
       MultipleSegmentInfoCallback callback,
       std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> segment_infos);
 
@@ -101,8 +101,7 @@
       SegmentInfoList segment_infos_from_default_model);
 
   // Default model providers.
-  std::map<OptimizationTarget, std::unique_ptr<ModelProvider>>
-      default_model_providers_;
+  std::map<SegmentId, std::unique_ptr<ModelProvider>> default_model_providers_;
   const raw_ptr<ModelProviderFactory> model_provider_factory_;
 
   base::WeakPtrFactory<DefaultModelManager> weak_ptr_factory_{this};
diff --git a/components/segmentation_platform/internal/execution/default_model_manager_unittest.cc b/components/segmentation_platform/internal/execution/default_model_manager_unittest.cc
index c7979cc4..32abfcc 100644
--- a/components/segmentation_platform/internal/execution/default_model_manager_unittest.cc
+++ b/components/segmentation_platform/internal/execution/default_model_manager_unittest.cc
@@ -18,18 +18,18 @@
 #include "third_party/abseil-cpp/absl/types/optional.h"
 
 using base::test::RunOnceCallback;
-using optimization_guide::proto::OptimizationTarget;
 using testing::_;
 
 namespace segmentation_platform {
 
+using proto::SegmentId;
+
 class DefaultModelManagerTest : public testing::Test {
  public:
   DefaultModelManagerTest() : model_provider_factory_(&model_provider_data_) {}
   ~DefaultModelManagerTest() override = default;
 
-  MockModelProvider& FindHandler(
-      optimization_guide::proto::OptimizationTarget segment_id) {
+  MockModelProvider& FindHandler(proto::SegmentId segment_id) {
     return *(*model_provider_data_.default_model_providers.find(segment_id))
                 .second;
   }
@@ -52,14 +52,11 @@
 };
 
 TEST_F(DefaultModelManagerTest, BasicTest) {
-  const auto segment_1 =
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
-  const auto segment_2 =
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE;
-  const auto segment_3 =
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_VOICE;
+  const auto segment_1 = SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
+  const auto segment_2 = SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE;
+  const auto segment_3 = SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_VOICE;
   const auto segment_4 =
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_QUERY_TILES;
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_QUERY_TILES;
 
   // Set some model versions.
   const int model_version_db = 4;
diff --git a/components/segmentation_platform/internal/execution/mock_model_provider.cc b/components/segmentation_platform/internal/execution/mock_model_provider.cc
index 7c16a16..d9915c05 100644
--- a/components/segmentation_platform/internal/execution/mock_model_provider.cc
+++ b/components/segmentation_platform/internal/execution/mock_model_provider.cc
@@ -18,7 +18,7 @@
 
 // Stores the client callbacks to |data|.
 void StoreClientCallback(
-    optimization_guide::proto::OptimizationTarget segment_id,
+    proto::SegmentId segment_id,
     TestModelProviderFactory::Data* data,
     const ModelProvider::ModelUpdatedCallback& model_updated_callback) {
   data->model_providers_callbacks.emplace(
@@ -28,7 +28,7 @@
 }  // namespace
 
 MockModelProvider::MockModelProvider(
-    optimization_guide::proto::OptimizationTarget segment_id,
+    proto::SegmentId segment_id,
     base::RepeatingCallback<void(const ModelProvider::ModelUpdatedCallback&)>
         get_client_callback)
     : ModelProvider(segment_id), get_client_callback_(get_client_callback) {
@@ -44,7 +44,7 @@
 TestModelProviderFactory::Data::~Data() = default;
 
 std::unique_ptr<ModelProvider> TestModelProviderFactory::CreateProvider(
-    optimization_guide::proto::OptimizationTarget segment_id) {
+    proto::SegmentId segment_id) {
   auto provider = std::make_unique<MockModelProvider>(
       segment_id, base::BindRepeating(&StoreClientCallback, segment_id, data_));
   data_->model_providers.emplace(std::make_pair(segment_id, provider.get()));
@@ -52,7 +52,7 @@
 }
 
 std::unique_ptr<ModelProvider> TestModelProviderFactory::CreateDefaultProvider(
-    optimization_guide::proto::OptimizationTarget segment_id) {
+    proto::SegmentId segment_id) {
   if (!base::Contains(data_->segments_supporting_default_model, segment_id))
     return nullptr;
 
diff --git a/components/segmentation_platform/internal/execution/mock_model_provider.h b/components/segmentation_platform/internal/execution/mock_model_provider.h
index a1c679d4..b8dc7619 100644
--- a/components/segmentation_platform/internal/execution/mock_model_provider.h
+++ b/components/segmentation_platform/internal/execution/mock_model_provider.h
@@ -14,15 +14,15 @@
 #include "testing/gmock/include/gmock/gmock.h"
 #include "testing/gtest/include/gtest/gtest.h"
 
-using optimization_guide::proto::OptimizationTarget;
-
 namespace segmentation_platform {
 
+using proto::SegmentId;
+
 // Mock model provider for testing, to be used with TestModelProviderFactory.
 class MockModelProvider : public ModelProvider {
  public:
   MockModelProvider(
-      optimization_guide::proto::OptimizationTarget segment_id,
+      proto::SegmentId segment_id,
       base::RepeatingCallback<void(const ModelProvider::ModelUpdatedCallback&)>
           get_client_callback);
   ~MockModelProvider() override;
@@ -56,21 +56,18 @@
 
     // Map of targets to model providers, added when provider is created. The
     // list is not cleared when providers are destroyed.
-    std::map<optimization_guide::proto::OptimizationTarget, MockModelProvider*>
-        model_providers;
+    std::map<proto::SegmentId, MockModelProvider*> model_providers;
 
     // Map of targets to default model providers, added when provider is
     // created. The list is not cleared when providers are destroyed.
-    std::map<optimization_guide::proto::OptimizationTarget, MockModelProvider*>
-        default_model_providers;
+    std::map<proto::SegmentId, MockModelProvider*> default_model_providers;
 
     // Map from target to updated callback, recorded when InitAndFetchModel()
     // was called on any provider.
-    std::map<optimization_guide::proto::OptimizationTarget,
-             ModelProvider::ModelUpdatedCallback>
+    std::map<proto::SegmentId, ModelProvider::ModelUpdatedCallback>
         model_providers_callbacks;
 
-    std::vector<OptimizationTarget> segments_supporting_default_model;
+    std::vector<SegmentId> segments_supporting_default_model;
   };
 
   // Records requests to `data`. `data` is not owned, and the caller must ensure
@@ -81,10 +78,10 @@
   // ModelProviderFactory impl, that keeps track of the created provider and
   // callbacks in |data_|.
   std::unique_ptr<ModelProvider> CreateProvider(
-      optimization_guide::proto::OptimizationTarget segment_id) override;
+      proto::SegmentId segment_id) override;
 
   std::unique_ptr<ModelProvider> CreateDefaultProvider(
-      optimization_guide::proto::OptimizationTarget) override;
+      proto::SegmentId) override;
 
  private:
   raw_ptr<Data> data_;
diff --git a/components/segmentation_platform/internal/execution/model_execution_manager.h b/components/segmentation_platform/internal/execution/model_execution_manager.h
index db967a4..3e87570 100644
--- a/components/segmentation_platform/internal/execution/model_execution_manager.h
+++ b/components/segmentation_platform/internal/execution/model_execution_manager.h
@@ -6,7 +6,7 @@
 #define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_EXECUTION_MODEL_EXECUTION_MANAGER_H_
 
 #include "base/callback_forward.h"
-#include "components/optimization_guide/proto/models.pb.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 
 namespace segmentation_platform {
 namespace proto {
@@ -31,8 +31,7 @@
   using SegmentationModelUpdatedCallback =
       base::RepeatingCallback<void(proto::SegmentInfo)>;
 
-  virtual ModelProvider* GetProvider(
-      optimization_guide::proto::OptimizationTarget segment_id) = 0;
+  virtual ModelProvider* GetProvider(proto::SegmentId segment_id) = 0;
 
  protected:
   ModelExecutionManager() = default;
diff --git a/components/segmentation_platform/internal/execution/model_execution_manager_impl.cc b/components/segmentation_platform/internal/execution/model_execution_manager_impl.cc
index 0c58d556..66ee95f 100644
--- a/components/segmentation_platform/internal/execution/model_execution_manager_impl.cc
+++ b/components/segmentation_platform/internal/execution/model_execution_manager_impl.cc
@@ -16,24 +16,23 @@
 #include "base/time/clock.h"
 #include "base/time/time.h"
 #include "base/trace_event/typed_macros.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/execution/model_execution_manager.h"
 #include "components/segmentation_platform/internal/metadata/metadata_utils.h"
 #include "components/segmentation_platform/internal/proto/model_metadata.pb.h"
 #include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
 #include "components/segmentation_platform/internal/stats.h"
 #include "components/segmentation_platform/public/model_provider.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 #include "third_party/abseil-cpp/absl/types/optional.h"
 
 namespace optimization_guide {
 class OptimizationGuideModelProvider;
-using proto::OptimizationTarget;
 }  // namespace optimization_guide
 
 namespace segmentation_platform {
 
 ModelExecutionManagerImpl::ModelExecutionManagerImpl(
-    const base::flat_set<OptimizationTarget>& segment_ids,
+    const base::flat_set<SegmentId>& segment_ids,
     ModelProviderFactory* model_provider_factory,
     base::Clock* clock,
     SegmentInfoDatabase* segment_database,
@@ -41,7 +40,7 @@
     : clock_(clock),
       segment_database_(segment_database),
       model_updated_callback_(model_updated_callback) {
-  for (OptimizationTarget segment_id : segment_ids) {
+  for (SegmentId segment_id : segment_ids) {
     std::unique_ptr<ModelProvider> provider =
         model_provider_factory->CreateProvider(segment_id);
     provider->InitAndFetchModel(base::BindRepeating(
@@ -54,21 +53,20 @@
 ModelExecutionManagerImpl::~ModelExecutionManagerImpl() = default;
 
 ModelProvider* ModelExecutionManagerImpl::GetProvider(
-    optimization_guide::proto::OptimizationTarget segment_id) {
+    proto::SegmentId segment_id) {
   auto it = model_providers_.find(segment_id);
   DCHECK(it != model_providers_.end());
   return it->second.get();
 }
 
 void ModelExecutionManagerImpl::OnSegmentationModelUpdated(
-    optimization_guide::proto::OptimizationTarget segment_id,
+    proto::SegmentId segment_id,
     proto::SegmentationModelMetadata metadata,
     int64_t model_version) {
   TRACE_EVENT("segmentation_platform",
               "ModelExecutionManagerImpl::OnSegmentationModelUpdated");
   stats::RecordModelDeliveryReceived(segment_id);
-  if (segment_id == optimization_guide::proto::OptimizationTarget::
-                        OPTIMIZATION_TARGET_UNKNOWN) {
+  if (segment_id == proto::SegmentId::OPTIMIZATION_TARGET_UNKNOWN) {
     return;
   }
 
@@ -91,7 +89,7 @@
 }
 
 void ModelExecutionManagerImpl::OnSegmentInfoFetchedForModelUpdate(
-    optimization_guide::proto::OptimizationTarget segment_id,
+    proto::SegmentId segment_id,
     proto::SegmentationModelMetadata metadata,
     int64_t model_version,
     absl::optional<proto::SegmentInfo> old_segment_info) {
diff --git a/components/segmentation_platform/internal/execution/model_execution_manager_impl.h b/components/segmentation_platform/internal/execution/model_execution_manager_impl.h
index 8969651..e3445835 100644
--- a/components/segmentation_platform/internal/execution/model_execution_manager_impl.h
+++ b/components/segmentation_platform/internal/execution/model_execution_manager_impl.h
@@ -13,9 +13,9 @@
 #include "base/containers/flat_set.h"
 #include "base/memory/raw_ptr.h"
 #include "base/memory/weak_ptr.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/database/segment_info_database.h"
 #include "components/segmentation_platform/internal/execution/model_execution_manager.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 #include "third_party/abseil-cpp/absl/types/optional.h"
 
 namespace base {
@@ -38,7 +38,7 @@
 class ModelExecutionManagerImpl : public ModelExecutionManager {
  public:
   ModelExecutionManagerImpl(
-      const base::flat_set<OptimizationTarget>& segment_ids,
+      const base::flat_set<SegmentId>& segment_ids,
       ModelProviderFactory* model_provider_factory,
       base::Clock* clock,
       SegmentInfoDatabase* segment_database,
@@ -51,8 +51,7 @@
       delete;
 
   // ModelExecutionManager override:
-  ModelProvider* GetProvider(
-      optimization_guide::proto::OptimizationTarget segment_id) override;
+  ModelProvider* GetProvider(proto::SegmentId segment_id) override;
 
  private:
   friend class SegmentationPlatformServiceImplTest;
@@ -61,10 +60,9 @@
   // Callback for whenever a SegmentationModelHandler is informed that the
   // underlying ML model file has been updated. If there is an available
   // model, this will be called at least once per session.
-  void OnSegmentationModelUpdated(
-      optimization_guide::proto::OptimizationTarget segment_id,
-      proto::SegmentationModelMetadata metadata,
-      int64_t model_version);
+  void OnSegmentationModelUpdated(proto::SegmentId segment_id,
+                                  proto::SegmentationModelMetadata metadata,
+                                  int64_t model_version);
 
   // Callback after fetching the current SegmentInfo from the
   // SegmentInfoDatabase. This is part of the flow for informing the
@@ -72,7 +70,7 @@
   // Merges the PredictionResult from the previously stored SegmentInfo with
   // the newly updated one, and stores the new version in the DB.
   void OnSegmentInfoFetchedForModelUpdate(
-      optimization_guide::proto::OptimizationTarget segment_id,
+      proto::SegmentId segment_id,
       proto::SegmentationModelMetadata metadata,
       int64_t model_version,
       absl::optional<proto::SegmentInfo> segment_info);
@@ -83,7 +81,7 @@
                                   bool success);
 
   // All the relevant handlers for each of the segments.
-  std::map<OptimizationTarget, std::unique_ptr<ModelProvider>> model_providers_;
+  std::map<SegmentId, std::unique_ptr<ModelProvider>> model_providers_;
 
   // Used to access the current time.
   raw_ptr<base::Clock> clock_;
diff --git a/components/segmentation_platform/internal/execution/model_execution_manager_impl_unittest.cc b/components/segmentation_platform/internal/execution/model_execution_manager_impl_unittest.cc
index 882d042..b25c150 100644
--- a/components/segmentation_platform/internal/execution/model_execution_manager_impl_unittest.cc
+++ b/components/segmentation_platform/internal/execution/model_execution_manager_impl_unittest.cc
@@ -18,7 +18,6 @@
 #include "base/test/simple_test_clock.h"
 #include "base/test/task_environment.h"
 #include "base/time/time.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/database/mock_signal_database.h"
 #include "components/segmentation_platform/internal/database/signal_database.h"
 #include "components/segmentation_platform/internal/database/test_segment_info_database.h"
@@ -31,6 +30,7 @@
 #include "components/segmentation_platform/internal/proto/model_metadata.pb.h"
 #include "components/segmentation_platform/internal/proto/types.pb.h"
 #include "components/segmentation_platform/public/model_provider.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 #include "testing/gmock/include/gmock/gmock.h"
 #include "testing/gtest/include/gtest/gtest.h"
 
@@ -57,22 +57,22 @@
               (override));
   MOCK_METHOD(void,
               GetSegmentInfoForSegments,
-              (const std::vector<OptimizationTarget>& segment_ids,
+              (const std::vector<SegmentId>& segment_ids,
                MultipleSegmentInfoCallback callback),
               (override));
   MOCK_METHOD(void,
               GetSegmentInfo,
-              (OptimizationTarget segment_id, SegmentInfoCallback callback),
+              (SegmentId segment_id, SegmentInfoCallback callback),
               (override));
   MOCK_METHOD(void,
               UpdateSegment,
-              (OptimizationTarget segment_id,
+              (SegmentId segment_id,
                absl::optional<proto::SegmentInfo> segment_info,
                SuccessCallback callback),
               (override));
   MOCK_METHOD(void,
               SaveSegmentResult,
-              (OptimizationTarget segment_id,
+              (SegmentId segment_id,
                absl::optional<proto::PredictionResult> result,
                SuccessCallback callback),
               (override));
@@ -100,7 +100,7 @@
   }
 
   void CreateModelExecutionManager(
-      std::vector<OptimizationTarget> segment_ids,
+      std::vector<SegmentId> segment_ids,
       const ModelExecutionManager::SegmentationModelUpdatedCallback& callback) {
     model_execution_manager_ = std::make_unique<ModelExecutionManagerImpl>(
         segment_ids, &model_provider_factory_, &clock_, segment_database_.get(),
@@ -109,8 +109,7 @@
 
   void RunUntilIdle() { task_environment_.RunUntilIdle(); }
 
-  MockModelProvider& FindHandler(
-      optimization_guide::proto::OptimizationTarget segment_id) {
+  MockModelProvider& FindHandler(proto::SegmentId segment_id) {
     return *(*model_provider_data_.model_providers.find(segment_id)).second;
   }
 
@@ -137,8 +136,7 @@
   // Construct the ModelExecutionManager.
   base::MockCallback<ModelExecutionManager::SegmentationModelUpdatedCallback>
       callback;
-  auto segment_id =
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
+  auto segment_id = SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
   CreateModelExecutionManager({segment_id}, callback.Get());
 
   // Create invalid metadata, which should be ignored.
@@ -157,8 +155,7 @@
 TEST_F(ModelExecutionManagerTest, OnSegmentationModelUpdatedNoOldMetadata) {
   base::MockCallback<ModelExecutionManager::SegmentationModelUpdatedCallback>
       callback;
-  auto segment_id =
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
+  auto segment_id = SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
   CreateModelExecutionManager({segment_id}, callback.Get());
 
   proto::SegmentInfo segment_info;
@@ -195,8 +192,7 @@
        OnSegmentationModelUpdatedWithPreviousMetadataAndPredictionResult) {
   base::MockCallback<ModelExecutionManager::SegmentationModelUpdatedCallback>
       callback;
-  auto segment_id =
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
+  auto segment_id = SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
   CreateModelExecutionManager({segment_id}, callback.Get());
 
   // Fill in old data in the SegmentInfo database.
diff --git a/components/segmentation_platform/internal/execution/model_executor_impl.cc b/components/segmentation_platform/internal/execution/model_executor_impl.cc
index bbb6c13..b752e72 100644
--- a/components/segmentation_platform/internal/execution/model_executor_impl.cc
+++ b/components/segmentation_platform/internal/execution/model_executor_impl.cc
@@ -9,19 +9,19 @@
 #include "base/time/clock.h"
 #include "base/time/time.h"
 #include "base/trace_event/typed_macros.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/execution/execution_request.h"
 #include "components/segmentation_platform/internal/execution/processing/feature_list_query_processor.h"
 #include "components/segmentation_platform/internal/segmentation_ukm_helper.h"
 #include "components/segmentation_platform/internal/stats.h"
 #include "components/segmentation_platform/public/model_provider.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 #include "third_party/perfetto/include/perfetto/tracing/track.h"
 
 namespace segmentation_platform {
 namespace {
-using optimization_guide::proto::OptimizationTarget;
 using processing::FeatureListQueryProcessor;
-}
+using proto::SegmentId;
+}  // namespace
 
 struct ModelExecutorImpl::ModelExecutionTraceEvent {
   ModelExecutionTraceEvent(const char* event_name,
@@ -55,7 +55,7 @@
   // https://crbug.com/1021571.
   std::unique_ptr<ModelExecutionTraceEvent> trace_event;
 
-  OptimizationTarget segment_id;
+  SegmentId segment_id;
   int64_t model_version = 0;
   raw_ptr<ModelProvider> model_provider = nullptr;
   bool record_metrics_for_default = false;
@@ -90,7 +90,7 @@
 void ModelExecutorImpl::ExecuteModel(
     std::unique_ptr<ExecutionRequest> request) {
   const proto::SegmentInfo& segment_info = *request->segment_info;
-  OptimizationTarget segment_id = segment_info.segment_id();
+  SegmentId segment_id = segment_info.segment_id();
 
   // Create an ExecutionState that will stay with this request until it has been
   // fully processed.
@@ -157,9 +157,7 @@
     for (unsigned i = 0; i < state->input_tensor.size(); ++i)
       log_input << " feature " << i << ": " << state->input_tensor[i];
     VLOG(1) << "Segmentation model input: " << log_input.str()
-            << " for segment "
-            << optimization_guide::proto::OptimizationTarget_Name(
-                   state->segment_id);
+            << " for segment " << proto::SegmentId_Name(state->segment_id);
   }
   const std::vector<float>& const_input_tensor = std::move(state->input_tensor);
   stats::RecordModelExecutionZeroValuePercent(state->segment_id,
@@ -182,8 +180,7 @@
       clock_->Now() - state->model_execution_start_time);
   if (result.has_value()) {
     VLOG(1) << "Segmentation model result: " << *result << " for segment "
-            << optimization_guide::proto::OptimizationTarget_Name(
-                   state->segment_id);
+            << proto::SegmentId_Name(state->segment_id);
     stats::RecordModelExecutionResult(state->segment_id, result.value());
     if (state->model_version && SegmentationUkmHelper::AllowedToUploadData(
                                     state->signal_storage_length, clock_)) {
@@ -195,8 +192,7 @@
                               ModelExecutionStatus::kSuccess);
   } else {
     VLOG(1) << "Segmentation model returned no result for segment "
-            << optimization_guide::proto::OptimizationTarget_Name(
-                   state->segment_id);
+            << proto::SegmentId_Name(state->segment_id);
     RunModelExecutionCallback(std::move(state), 0,
                               ModelExecutionStatus::kExecutionError);
   }
diff --git a/components/segmentation_platform/internal/execution/model_executor_impl_unittest.cc b/components/segmentation_platform/internal/execution/model_executor_impl_unittest.cc
index 450963b9..a5ce702 100644
--- a/components/segmentation_platform/internal/execution/model_executor_impl_unittest.cc
+++ b/components/segmentation_platform/internal/execution/model_executor_impl_unittest.cc
@@ -18,7 +18,6 @@
 #include "base/test/simple_test_clock.h"
 #include "base/test/task_environment.h"
 #include "base/time/time.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/database/mock_signal_database.h"
 #include "components/segmentation_platform/internal/database/signal_database.h"
 #include "components/segmentation_platform/internal/database/test_segment_info_database.h"
@@ -32,6 +31,7 @@
 #include "components/segmentation_platform/internal/proto/model_metadata.pb.h"
 #include "components/segmentation_platform/internal/proto/types.pb.h"
 #include "components/segmentation_platform/public/model_provider.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 #include "testing/gmock/include/gmock/gmock.h"
 #include "testing/gtest/include/gtest/gtest.h"
 
@@ -45,8 +45,8 @@
 
 namespace segmentation_platform {
 
-const OptimizationTarget kSegmentId =
-    OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
+const SegmentId kSegmentId =
+    SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
 
 class ModelExecutorTest : public testing::Test {
  public:
@@ -150,7 +150,7 @@
 
   // Initialize with required metadata.
   test::TestSegmentInfoDatabase metadata_writer;
-  const OptimizationTarget segment_id = kSegmentId;
+  const SegmentId segment_id = kSegmentId;
   metadata_writer.SetBucketDuration(segment_id, 3, proto::TimeUnit::HOUR);
   std::string user_action_name = "some_user_action";
   metadata_writer.AddUserActionFeature(segment_id, user_action_name, 3, 3,
diff --git a/components/segmentation_platform/internal/execution/optimization_guide/optimization_guide_segmentation_model_handler.cc b/components/segmentation_platform/internal/execution/optimization_guide/optimization_guide_segmentation_model_handler.cc
index 3fd50ee..44205cf 100644
--- a/components/segmentation_platform/internal/execution/optimization_guide/optimization_guide_segmentation_model_handler.cc
+++ b/components/segmentation_platform/internal/execution/optimization_guide/optimization_guide_segmentation_model_handler.cc
@@ -12,6 +12,7 @@
 #include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/execution/optimization_guide/segmentation_model_executor.h"
 #include "components/segmentation_platform/internal/proto/model_metadata.pb.h"
+#include "components/segmentation_platform/internal/segment_id_convertor.h"
 #include "components/segmentation_platform/internal/stats.h"
 
 namespace segmentation_platform {
@@ -20,7 +21,7 @@
     OptimizationGuideSegmentationModelHandler(
         optimization_guide::OptimizationGuideModelProvider* model_provider,
         scoped_refptr<base::SequencedTaskRunner> background_task_runner,
-        optimization_guide::proto::OptimizationTarget optimization_target,
+        optimization_guide::proto::OptimizationTarget segment_id,
         const ModelUpdatedCallback& model_updated_callback,
         absl::optional<optimization_guide::proto::Any>&& model_metadata)
     : optimization_guide::ModelHandler<float, const std::vector<float>&>(
@@ -28,11 +29,11 @@
           background_task_runner,
           std::make_unique<SegmentationModelExecutor>(),
           /*model_inference_timeout=*/absl::nullopt,
-          optimization_target,
+          segment_id,
           model_metadata),
       model_updated_callback_(model_updated_callback) {
   stats::RecordModelAvailability(
-      optimization_target,
+      OptimizationTargetToSegmentId(segment_id),
       stats::SegmentationModelAvailability::kModelHandlerCreated);
 }
 
@@ -40,12 +41,11 @@
     ~OptimizationGuideSegmentationModelHandler() = default;
 
 void OptimizationGuideSegmentationModelHandler::OnModelUpdated(
-    optimization_guide::proto::OptimizationTarget optimization_target,
+    optimization_guide::proto::OptimizationTarget segment_id,
     const optimization_guide::ModelInfo& model_info) {
   // First invoke parent to update internal status.
   optimization_guide::ModelHandler<
-      float, const std::vector<float>&>::OnModelUpdated(optimization_target,
-                                                        model_info);
+      float, const std::vector<float>&>::OnModelUpdated(segment_id, model_info);
   // The parent class should always set the model availability to true after
   // having received an updated model.
   DCHECK(ModelAvailable());
@@ -55,22 +55,23 @@
   absl::optional<proto::SegmentationModelMetadata> segmentation_model_metadata =
       ParsedSupportedFeaturesForLoadedModel<proto::SegmentationModelMetadata>();
   stats::RecordModelDeliveryHasMetadata(
-      optimization_target, segmentation_model_metadata.has_value());
+      OptimizationTargetToSegmentId(segment_id),
+      segmentation_model_metadata.has_value());
   if (!segmentation_model_metadata.has_value()) {
     // This is not expected to happen, since the optimization guide server is
     // expected to pass this along. Either something failed horribly on the way,
     // we failed to read the metadata, or the server side configuration is
     // wrong.
     stats::RecordModelAvailability(
-        optimization_target,
+        OptimizationTargetToSegmentId(segment_id),
         stats::SegmentationModelAvailability::kMetadataInvalid);
     return;
   }
   stats::RecordModelAvailability(
-      optimization_target,
+      OptimizationTargetToSegmentId(segment_id),
       stats::SegmentationModelAvailability::kModelAvailable);
 
-  model_updated_callback_.Run(optimization_target,
+  model_updated_callback_.Run(OptimizationTargetToSegmentId(segment_id),
                               std::move(*segmentation_model_metadata),
                               model_info.GetVersion());
 }
diff --git a/components/segmentation_platform/internal/execution/optimization_guide/optimization_guide_segmentation_model_handler.h b/components/segmentation_platform/internal/execution/optimization_guide/optimization_guide_segmentation_model_handler.h
index 4eb60ff..a619f53 100644
--- a/components/segmentation_platform/internal/execution/optimization_guide/optimization_guide_segmentation_model_handler.h
+++ b/components/segmentation_platform/internal/execution/optimization_guide/optimization_guide_segmentation_model_handler.h
@@ -10,6 +10,7 @@
 
 #include "components/optimization_guide/core/model_handler.h"
 #include "components/optimization_guide/proto/models.pb.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 
 namespace optimization_guide {
 class OptimizationGuideModelProvider;
@@ -29,15 +30,13 @@
     : public optimization_guide::ModelHandler<float,
                                               const std::vector<float>&> {
  public:
-  using ModelUpdatedCallback = base::RepeatingCallback<void(
-      optimization_guide::proto::OptimizationTarget,
-      proto::SegmentationModelMetadata,
-      int64_t)>;
+  using ModelUpdatedCallback = base::RepeatingCallback<
+      void(proto::SegmentId, proto::SegmentationModelMetadata, int64_t)>;
 
   explicit OptimizationGuideSegmentationModelHandler(
       optimization_guide::OptimizationGuideModelProvider* model_provider,
       scoped_refptr<base::SequencedTaskRunner> background_task_runner,
-      optimization_guide::proto::OptimizationTarget optimization_target,
+      optimization_guide::proto::OptimizationTarget segment_id,
       const ModelUpdatedCallback& model_updated_callback,
       absl::optional<optimization_guide::proto::Any>&& model_metadata);
 
@@ -50,9 +49,8 @@
       const OptimizationGuideSegmentationModelHandler&) = delete;
 
   // optimization_guide::ModelHandler overrides.
-  void OnModelUpdated(
-      optimization_guide::proto::OptimizationTarget optimization_target,
-      const optimization_guide::ModelInfo& model_info) override;
+  void OnModelUpdated(optimization_guide::proto::OptimizationTarget segment_id,
+                      const optimization_guide::ModelInfo& model_info) override;
 
  private:
   // Callback to invoke whenever the model file has been updated. If there is
diff --git a/components/segmentation_platform/internal/execution/optimization_guide/optimization_guide_segmentation_model_provider.cc b/components/segmentation_platform/internal/execution/optimization_guide/optimization_guide_segmentation_model_provider.cc
index be5e59d..512614b 100644
--- a/components/segmentation_platform/internal/execution/optimization_guide/optimization_guide_segmentation_model_provider.cc
+++ b/components/segmentation_platform/internal/execution/optimization_guide/optimization_guide_segmentation_model_provider.cc
@@ -10,11 +10,12 @@
 #include "base/threading/sequenced_task_runner_handle.h"
 #include "components/optimization_guide/core/model_executor.h"
 #include "components/optimization_guide/proto/common_types.pb.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/execution/optimization_guide/optimization_guide_segmentation_model_handler.h"
 #include "components/segmentation_platform/internal/execution/optimization_guide/segmentation_model_executor.h"
 #include "components/segmentation_platform/internal/proto/model_metadata.pb.h"
+#include "components/segmentation_platform/internal/segment_id_convertor.h"
 #include "components/segmentation_platform/internal/stats.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 
 namespace segmentation_platform {
 
@@ -43,8 +44,8 @@
     OptimizationGuideSegmentationModelProvider(
         optimization_guide::OptimizationGuideModelProvider* model_provider,
         scoped_refptr<base::SequencedTaskRunner> background_task_runner,
-        optimization_guide::proto::OptimizationTarget optimization_target)
-    : ModelProvider(optimization_target),
+        proto::SegmentId segment_id)
+    : ModelProvider(segment_id),
       model_provider_(model_provider),
       background_task_runner_(background_task_runner) {}
 
@@ -55,8 +56,9 @@
     const ModelUpdatedCallback& model_updated_callback) {
   DCHECK(!model_handler_);
   model_handler_ = std::make_unique<OptimizationGuideSegmentationModelHandler>(
-      model_provider_, background_task_runner_, optimization_target_,
-      model_updated_callback, GetModelFetchConfig());
+      model_provider_, background_task_runner_,
+      SegmentIdToOptimizationTarget(segment_id_), model_updated_callback,
+      GetModelFetchConfig());
 }
 
 void OptimizationGuideSegmentationModelProvider::ExecuteModelWithInput(
diff --git a/components/segmentation_platform/internal/execution/optimization_guide/optimization_guide_segmentation_model_provider.h b/components/segmentation_platform/internal/execution/optimization_guide/optimization_guide_segmentation_model_provider.h
index 2e7bdee..c1bbe2a 100644
--- a/components/segmentation_platform/internal/execution/optimization_guide/optimization_guide_segmentation_model_provider.h
+++ b/components/segmentation_platform/internal/execution/optimization_guide/optimization_guide_segmentation_model_provider.h
@@ -9,8 +9,8 @@
 #include <vector>
 
 #include "components/optimization_guide/core/model_handler.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/public/model_provider.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 
 namespace optimization_guide {
 class OptimizationGuideSegmentationModelProvider;
@@ -27,7 +27,7 @@
   OptimizationGuideSegmentationModelProvider(
       optimization_guide::OptimizationGuideModelProvider* model_provider,
       scoped_refptr<base::SequencedTaskRunner> background_task_runner,
-      optimization_guide::proto::OptimizationTarget optimization_target);
+      proto::SegmentId segment_id);
 
   ~OptimizationGuideSegmentationModelProvider() override;
 
diff --git a/components/segmentation_platform/internal/execution/optimization_guide/optimization_guide_segmentation_model_provider_unittest.cc b/components/segmentation_platform/internal/execution/optimization_guide/optimization_guide_segmentation_model_provider_unittest.cc
index 5586081..13bec7e 100644
--- a/components/segmentation_platform/internal/execution/optimization_guide/optimization_guide_segmentation_model_provider_unittest.cc
+++ b/components/segmentation_platform/internal/execution/optimization_guide/optimization_guide_segmentation_model_provider_unittest.cc
@@ -25,8 +25,7 @@
     registered_model_metadata_.insert_or_assign(target, model_metadata);
   }
 
-  bool DidRegisterForTarget(
-      optimization_guide::proto::OptimizationTarget target) const {
+  bool DidRegisterForTarget(proto::SegmentId target) const {
     auto it = registered_model_metadata_.find(target);
     if (it == registered_model_metadata_.end())
       return false;
@@ -67,7 +66,7 @@
   }
 
   std::unique_ptr<OptimizationGuideSegmentationModelProvider>
-  CreateModelProvider(optimization_guide::proto::OptimizationTarget target) {
+  CreateModelProvider(proto::SegmentId target) {
     return std::make_unique<OptimizationGuideSegmentationModelProvider>(
         model_observer_tracker_.get(), task_runner_, target);
   }
@@ -81,46 +80,41 @@
 
 TEST_F(OptimizationGuideSegmentationModelProviderTest, InitAndFetchModel) {
   std::unique_ptr<OptimizationGuideSegmentationModelProvider> provider =
-      CreateModelProvider(optimization_guide::proto::OptimizationTarget::
-                              OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
+      CreateModelProvider(
+          proto::SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
 
   // Not initialized yet.
   EXPECT_FALSE(model_observer_tracker_->DidRegisterForTarget(
-      optimization_guide::proto::OptimizationTarget::
-          OPTIMIZATION_TARGET_SEGMENTATION_SHARE));
+      proto::SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE));
 
   // Init should register observer.
   provider->InitAndFetchModel(base::DoNothing());
   EXPECT_TRUE(model_observer_tracker_->DidRegisterForTarget(
-      optimization_guide::proto::OptimizationTarget::
-          OPTIMIZATION_TARGET_SEGMENTATION_SHARE));
+      proto::SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE));
 
   // Different target does not register yet.
   EXPECT_FALSE(model_observer_tracker_->DidRegisterForTarget(
-      optimization_guide::proto::OptimizationTarget::
-          OPTIMIZATION_TARGET_SEGMENTATION_VOICE));
+      proto::SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_VOICE));
 
   // Initialize voice provider.
   std::unique_ptr<OptimizationGuideSegmentationModelProvider> provider2 =
-      CreateModelProvider(optimization_guide::proto::OptimizationTarget::
-                              OPTIMIZATION_TARGET_SEGMENTATION_VOICE);
+      CreateModelProvider(
+          proto::SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_VOICE);
   provider2->InitAndFetchModel(base::DoNothing());
 
   // 2 observers should be available:
   EXPECT_TRUE(model_observer_tracker_->DidRegisterForTarget(
-      optimization_guide::proto::OptimizationTarget::
-          OPTIMIZATION_TARGET_SEGMENTATION_SHARE));
+      proto::SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE));
 
   EXPECT_TRUE(model_observer_tracker_->DidRegisterForTarget(
-      optimization_guide::proto::OptimizationTarget::
-          OPTIMIZATION_TARGET_SEGMENTATION_VOICE));
+      proto::SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_VOICE));
 }
 
 TEST_F(OptimizationGuideSegmentationModelProviderTest,
        ExecuteModelWithoutFetch) {
   std::unique_ptr<OptimizationGuideSegmentationModelProvider> provider =
-      CreateModelProvider(optimization_guide::proto::OptimizationTarget::
-                              OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
+      CreateModelProvider(
+          proto::SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
 
   base::RunLoop run_loop;
   std::vector<float> input = {4, 5};
@@ -137,12 +131,11 @@
 
 TEST_F(OptimizationGuideSegmentationModelProviderTest, ExecuteModelWithFetch) {
   std::unique_ptr<OptimizationGuideSegmentationModelProvider> provider =
-      CreateModelProvider(optimization_guide::proto::OptimizationTarget::
-                              OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
+      CreateModelProvider(
+          proto::SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
   provider->InitAndFetchModel(base::DoNothing());
   EXPECT_TRUE(model_observer_tracker_->DidRegisterForTarget(
-      optimization_guide::proto::OptimizationTarget::
-          OPTIMIZATION_TARGET_SEGMENTATION_SHARE));
+      proto::SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE));
 
   base::RunLoop run_loop;
   std::vector<float> input = {4, 5};
diff --git a/components/segmentation_platform/internal/execution/optimization_guide/segmentation_model_executor.h b/components/segmentation_platform/internal/execution/optimization_guide/segmentation_model_executor.h
index 02cbdcd..3077e9e4 100644
--- a/components/segmentation_platform/internal/execution/optimization_guide/segmentation_model_executor.h
+++ b/components/segmentation_platform/internal/execution/optimization_guide/segmentation_model_executor.h
@@ -9,7 +9,7 @@
 #include <vector>
 
 #include "components/optimization_guide/core/base_model_executor.h"
-#include "components/optimization_guide/proto/models.pb.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 
 struct TfLiteTensor;
 
diff --git a/components/segmentation_platform/internal/execution/optimization_guide/segmentation_model_executor_unittest.cc b/components/segmentation_platform/internal/execution/optimization_guide/segmentation_model_executor_unittest.cc
index 87b8409..f943af4 100644
--- a/components/segmentation_platform/internal/execution/optimization_guide/segmentation_model_executor_unittest.cc
+++ b/components/segmentation_platform/internal/execution/optimization_guide/segmentation_model_executor_unittest.cc
@@ -19,24 +19,25 @@
 #include "components/optimization_guide/core/test_model_info_builder.h"
 #include "components/optimization_guide/core/test_optimization_guide_model_provider.h"
 #include "components/optimization_guide/proto/common_types.pb.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/execution/optimization_guide/optimization_guide_segmentation_model_handler.h"
 #include "components/segmentation_platform/internal/execution/optimization_guide/optimization_guide_segmentation_model_provider.h"
 #include "components/segmentation_platform/internal/proto/model_metadata.pb.h"
 #include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
+#include "components/segmentation_platform/internal/segment_id_convertor.h"
 #include "components/segmentation_platform/public/model_provider.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 #include "testing/gmock/include/gmock/gmock.h"
 #include "testing/gtest/include/gtest/gtest.h"
 
 using testing::_;
 
+namespace segmentation_platform {
 namespace {
-const auto kOptimizationTarget = optimization_guide::proto::OptimizationTarget::
-    OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
+const auto kSegmentId =
+    proto::SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
 const int64_t kModelVersion = 123;
 }  // namespace
 
-namespace segmentation_platform {
 bool AreEqual(const proto::SegmentationModelMetadata& a,
               const proto::SegmentationModelMetadata& b) {
   // Serializing two protos and comparing them is unsafe, in particular if they
@@ -75,7 +76,7 @@
     opt_guide_model_provider_ =
         std::make_unique<OptimizationGuideSegmentationModelProvider>(
             optimization_guide_segmentation_model_provider_.get(),
-            task_environment_.GetMainThreadTaskRunner(), kOptimizationTarget);
+            task_environment_.GetMainThreadTaskRunner(), kSegmentId);
     opt_guide_model_provider_->InitAndFetchModel(callback);
   }
 
@@ -110,8 +111,8 @@
                               .SetModelFilePath(model_file_path_)
                               .SetVersion(kModelVersion)
                               .Build();
-    opt_guide_model_handler().OnModelUpdated(kOptimizationTarget,
-                                             *model_metadata);
+    opt_guide_model_handler().OnModelUpdated(
+        SegmentIdToOptimizationTarget(kSegmentId), *model_metadata);
     RunUntilIdle();
   }
 
@@ -141,11 +142,11 @@
   CreateModelExecutor(base::BindRepeating(
       [](base::RunLoop* run_loop,
          proto::SegmentationModelMetadata original_metadata,
-         optimization_guide::proto::OptimizationTarget optimization_target,
+         proto::SegmentId segment_id,
          proto::SegmentationModelMetadata actual_metadata,
          int64_t model_version) {
         // Verify that the callback is invoked with the correct data.
-        EXPECT_EQ(kOptimizationTarget, optimization_target);
+        EXPECT_EQ(kSegmentId, segment_id);
         EXPECT_TRUE(AreEqual(original_metadata, actual_metadata));
         EXPECT_EQ(kModelVersion, model_version);
         run_loop->Quit();
diff --git a/components/segmentation_platform/internal/execution/processing/custom_input_processor_unittest.cc b/components/segmentation_platform/internal/execution/processing/custom_input_processor_unittest.cc
index 99e4607..70aa3cc 100644
--- a/components/segmentation_platform/internal/execution/processing/custom_input_processor_unittest.cc
+++ b/components/segmentation_platform/internal/execution/processing/custom_input_processor_unittest.cc
@@ -10,10 +10,10 @@
 #include "base/run_loop.h"
 #include "base/test/simple_test_clock.h"
 #include "base/test/task_environment.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/database/ukm_types.h"
 #include "components/segmentation_platform/internal/execution/processing/feature_processor_state.h"
 #include "components/segmentation_platform/internal/execution/processing/query_processor.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 #include "testing/gtest/include/gtest/gtest.h"
 
 namespace segmentation_platform::processing {
diff --git a/components/segmentation_platform/internal/execution/processing/feature_list_query_processor.cc b/components/segmentation_platform/internal/execution/processing/feature_list_query_processor.cc
index 0ae5f8c..3dd3185 100644
--- a/components/segmentation_platform/internal/execution/processing/feature_list_query_processor.cc
+++ b/components/segmentation_platform/internal/execution/processing/feature_list_query_processor.cc
@@ -38,7 +38,7 @@
 void FeatureListQueryProcessor::ProcessFeatureList(
     const proto::SegmentationModelMetadata& model_metadata,
     scoped_refptr<InputContext> input_context,
-    OptimizationTarget segment_id,
+    SegmentId segment_id,
     base::Time prediction_time,
     ProcessOption process_option,
     FeatureProcessorCallback callback) {
diff --git a/components/segmentation_platform/internal/execution/processing/feature_list_query_processor.h b/components/segmentation_platform/internal/execution/processing/feature_list_query_processor.h
index b748c3302..608a4fd 100644
--- a/components/segmentation_platform/internal/execution/processing/feature_list_query_processor.h
+++ b/components/segmentation_platform/internal/execution/processing/feature_list_query_processor.h
@@ -10,13 +10,13 @@
 #include <vector>
 
 #include "base/memory/weak_ptr.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/database/ukm_database.h"
 #include "components/segmentation_platform/internal/execution/processing/custom_input_processor.h"
 #include "components/segmentation_platform/internal/execution/processing/query_processor.h"
 #include "components/segmentation_platform/internal/execution/processing/uma_feature_processor.h"
 #include "components/segmentation_platform/internal/input_context.h"
 #include "components/segmentation_platform/internal/proto/model_metadata.pb.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 
 namespace segmentation_platform {
 class StorageService;
@@ -26,7 +26,7 @@
 class FeatureAggregator;
 class FeatureProcessorState;
 
-using optimization_guide::proto::OptimizationTarget;
+using proto::SegmentId;
 
 // FeatureListQueryProcessor takes a segmentation model's metadata, processes
 // each feature in the metadata's feature list in order and computes an input
@@ -60,7 +60,7 @@
   virtual void ProcessFeatureList(
       const proto::SegmentationModelMetadata& model_metadata,
       scoped_refptr<InputContext> input_context,
-      OptimizationTarget segment_id,
+      SegmentId segment_id,
       base::Time prediction_time,
       ProcessOption process_option,
       FeatureProcessorCallback callback);
diff --git a/components/segmentation_platform/internal/execution/processing/feature_list_query_processor_unittest.cc b/components/segmentation_platform/internal/execution/processing/feature_list_query_processor_unittest.cc
index 9f6646c..e3fe6488 100644
--- a/components/segmentation_platform/internal/execution/processing/feature_list_query_processor_unittest.cc
+++ b/components/segmentation_platform/internal/execution/processing/feature_list_query_processor_unittest.cc
@@ -45,7 +45,7 @@
         std::make_unique<StorageService>(nullptr, std::move(moved_signal_db),
                                          nullptr, nullptr, &ukm_data_manager_);
     clock_.SetNow(base::Time::Now());
-    segment_id_ = OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
+    segment_id_ = SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
   }
 
   void TearDown() override {
@@ -199,7 +199,7 @@
 
   base::SimpleTestClock clock_;
   base::test::TaskEnvironment task_environment_;
-  OptimizationTarget segment_id_;
+  SegmentId segment_id_;
   proto::SegmentationModelMetadata model_metadata;
   MockUkmDataManager ukm_data_manager_;
   std::unique_ptr<StorageService> storage_service_;
diff --git a/components/segmentation_platform/internal/execution/processing/feature_processor_state.cc b/components/segmentation_platform/internal/execution/processing/feature_processor_state.cc
index 4710198..a8ea6da 100644
--- a/components/segmentation_platform/internal/execution/processing/feature_processor_state.cc
+++ b/components/segmentation_platform/internal/execution/processing/feature_processor_state.cc
@@ -32,12 +32,12 @@
 FeatureProcessorState::FeatureProcessorState()
     : prediction_time_(base::Time::Now()),
       bucket_duration_(base::TimeDelta()),
-      segment_id_(OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN) {}
+      segment_id_(SegmentId::OPTIMIZATION_TARGET_UNKNOWN) {}
 
 FeatureProcessorState::FeatureProcessorState(
     base::Time prediction_time,
     base::TimeDelta bucket_duration,
-    OptimizationTarget segment_id,
+    SegmentId segment_id,
     std::deque<Data> data,
     scoped_refptr<InputContext> input_context,
     FeatureListQueryProcessor::FeatureProcessorCallback callback)
diff --git a/components/segmentation_platform/internal/execution/processing/feature_processor_state.h b/components/segmentation_platform/internal/execution/processing/feature_processor_state.h
index 129c225..68e0d96 100644
--- a/components/segmentation_platform/internal/execution/processing/feature_processor_state.h
+++ b/components/segmentation_platform/internal/execution/processing/feature_processor_state.h
@@ -11,16 +11,16 @@
 
 #include "base/time/clock.h"
 #include "base/time/time.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/database/ukm_types.h"
 #include "components/segmentation_platform/internal/execution/processing/feature_list_query_processor.h"
 #include "components/segmentation_platform/internal/input_context.h"
 #include "components/segmentation_platform/internal/proto/model_metadata.pb.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 #include "third_party/abseil-cpp/absl/types/optional.h"
 
 namespace segmentation_platform::processing {
 
-using optimization_guide::proto::OptimizationTarget;
+using proto::SegmentId;
 
 // FeatureProcessorState is responsible for storing all necessary state during
 // the processing of a model's metadata.
@@ -44,7 +44,7 @@
   FeatureProcessorState(
       base::Time prediction_time,
       base::TimeDelta bucket_duration,
-      OptimizationTarget segment_id,
+      SegmentId segment_id,
       std::deque<Data> data,
       scoped_refptr<InputContext> input_context,
       FeatureListQueryProcessor::FeatureProcessorCallback callback);
@@ -59,7 +59,7 @@
 
   base::Time prediction_time() const { return prediction_time_; }
 
-  OptimizationTarget segment_id() const { return segment_id_; }
+  SegmentId segment_id() const { return segment_id_; }
 
   bool error() const { return error_; }
 
@@ -87,7 +87,7 @@
  private:
   const base::Time prediction_time_;
   const base::TimeDelta bucket_duration_;
-  const OptimizationTarget segment_id_;
+  const SegmentId segment_id_;
   std::deque<Data> data_;
   scoped_refptr<InputContext> input_context_;
 
diff --git a/components/segmentation_platform/internal/execution/processing/mock_feature_list_query_processor.h b/components/segmentation_platform/internal/execution/processing/mock_feature_list_query_processor.h
index 64c9da0..990bac2 100644
--- a/components/segmentation_platform/internal/execution/processing/mock_feature_list_query_processor.h
+++ b/components/segmentation_platform/internal/execution/processing/mock_feature_list_query_processor.h
@@ -20,7 +20,7 @@
               ProcessFeatureList,
               (const proto::SegmentationModelMetadata&,
                scoped_refptr<InputContext> input_context,
-               optimization_guide::proto::OptimizationTarget,
+               proto::SegmentId,
                base::Time,
                FeatureListQueryProcessor::ProcessOption,
                FeatureProcessorCallback),
diff --git a/components/segmentation_platform/internal/execution/processing/sql_feature_processor_unittest.cc b/components/segmentation_platform/internal/execution/processing/sql_feature_processor_unittest.cc
index 6616a6a..1c9f117 100644
--- a/components/segmentation_platform/internal/execution/processing/sql_feature_processor_unittest.cc
+++ b/components/segmentation_platform/internal/execution/processing/sql_feature_processor_unittest.cc
@@ -10,11 +10,11 @@
 #include "base/test/gmock_callback_support.h"
 #include "base/test/simple_test_clock.h"
 #include "base/test/task_environment.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/database/mock_ukm_database.h"
 #include "components/segmentation_platform/internal/database/ukm_types.h"
 #include "components/segmentation_platform/internal/execution/processing/feature_processor_state.h"
 #include "components/segmentation_platform/internal/proto/model_metadata.pb.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 #include "testing/gtest/include/gtest/gtest.h"
 
 using base::test::RunOnceCallback;
diff --git a/components/segmentation_platform/internal/execution/processing/uma_feature_processor.cc b/components/segmentation_platform/internal/execution/processing/uma_feature_processor.cc
index e3541c1..ad2ff48 100644
--- a/components/segmentation_platform/internal/execution/processing/uma_feature_processor.cc
+++ b/components/segmentation_platform/internal/execution/processing/uma_feature_processor.cc
@@ -23,7 +23,7 @@
     FeatureAggregator* feature_aggregator,
     const base::Time prediction_time,
     const base::TimeDelta bucket_duration,
-    const OptimizationTarget segment_id)
+    const SegmentId segment_id)
     : uma_features_(std::move(uma_features)),
       signal_database_(signal_database),
       feature_aggregator_(feature_aggregator),
diff --git a/components/segmentation_platform/internal/execution/processing/uma_feature_processor.h b/components/segmentation_platform/internal/execution/processing/uma_feature_processor.h
index cec66c01..664900b 100644
--- a/components/segmentation_platform/internal/execution/processing/uma_feature_processor.h
+++ b/components/segmentation_platform/internal/execution/processing/uma_feature_processor.h
@@ -10,11 +10,11 @@
 
 #include "base/memory/raw_ptr.h"
 #include "base/memory/weak_ptr.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/database/signal_database.h"
 #include "components/segmentation_platform/internal/execution/processing/feature_aggregator.h"
 #include "components/segmentation_platform/internal/execution/processing/query_processor.h"
 #include "components/segmentation_platform/internal/proto/model_metadata.pb.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 
 namespace segmentation_platform::processing {
 class FeatureProcessorState;
@@ -30,7 +30,7 @@
       FeatureAggregator* feature_aggregator,
       const base::Time prediction_time,
       const base::TimeDelta bucket_duration,
-      const optimization_guide::proto::OptimizationTarget segment_id);
+      const proto::SegmentId segment_id);
 
   ~UmaFeatureProcessor() override;
 
@@ -69,7 +69,7 @@
   // Data needed for the processing of uma features.
   const base::Time prediction_time_;
   const base::TimeDelta bucket_duration_;
-  const optimization_guide::proto::OptimizationTarget segment_id_;
+  const proto::SegmentId segment_id_;
 
   // Temporary storage of the processing state object.
   // TODO(haileywang): Remove dependency to the state object once error check is
diff --git a/components/segmentation_platform/internal/metadata/metadata_utils.cc b/components/segmentation_platform/internal/metadata/metadata_utils.cc
index 7cf303ed..1863b27 100644
--- a/components/segmentation_platform/internal/metadata/metadata_utils.cc
+++ b/components/segmentation_platform/internal/metadata/metadata_utils.cc
@@ -11,13 +11,13 @@
 #include "base/strings/string_util.h"
 #include "base/strings/stringprintf.h"
 #include "base/time/time.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/database/signal_key.h"
 #include "components/segmentation_platform/internal/proto/aggregation.pb.h"
 #include "components/segmentation_platform/internal/proto/model_metadata.pb.h"
 #include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
 #include "components/segmentation_platform/internal/proto/types.pb.h"
 #include "components/segmentation_platform/public/features.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 #include "third_party/abseil-cpp/absl/types/optional.h"
 
 namespace segmentation_platform {
@@ -76,8 +76,7 @@
 }  // namespace
 
 ValidationResult ValidateSegmentInfo(const proto::SegmentInfo& segment_info) {
-  if (segment_info.segment_id() ==
-      optimization_guide::proto::OPTIMIZATION_TARGET_UNKNOWN) {
+  if (segment_info.segment_id() == proto::OPTIMIZATION_TARGET_UNKNOWN) {
     return ValidationResult::kSegmentIDNotFound;
   }
 
diff --git a/components/segmentation_platform/internal/metadata/metadata_utils.h b/components/segmentation_platform/internal/metadata/metadata_utils.h
index 85ff2bb..ff22148 100644
--- a/components/segmentation_platform/internal/metadata/metadata_utils.h
+++ b/components/segmentation_platform/internal/metadata/metadata_utils.h
@@ -6,17 +6,17 @@
 #define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_METADATA_METADATA_UTILS_H_
 
 #include "base/time/time.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/database/signal_key.h"
 #include "components/segmentation_platform/internal/execution/processing/query_processor.h"
 #include "components/segmentation_platform/internal/proto/model_metadata.pb.h"
 #include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
 #include "components/segmentation_platform/internal/proto/types.pb.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 #include "third_party/abseil-cpp/absl/types/optional.h"
 
-using optimization_guide::proto::OptimizationTarget;
-
 namespace segmentation_platform {
+using proto::SegmentId;
+
 namespace metadata_utils {
 
 // Keep up to date with SegmentationPlatformValidationResult in
diff --git a/components/segmentation_platform/internal/metadata/metadata_utils_unittest.cc b/components/segmentation_platform/internal/metadata/metadata_utils_unittest.cc
index 962e8799..336e790 100644
--- a/components/segmentation_platform/internal/metadata/metadata_utils_unittest.cc
+++ b/components/segmentation_platform/internal/metadata/metadata_utils_unittest.cc
@@ -5,11 +5,11 @@
 #include "components/segmentation_platform/internal/metadata/metadata_utils.h"
 
 #include "base/metrics/metrics_hashes.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/database/ukm_types.h"
 #include "components/segmentation_platform/internal/execution/processing/query_processor.h"
 #include "components/segmentation_platform/internal/proto/aggregation.pb.h"
 #include "components/segmentation_platform/internal/proto/model_metadata.pb.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 #include "testing/gtest/include/gtest/gtest.h"
 
 namespace segmentation_platform {
@@ -41,8 +41,8 @@
   EXPECT_EQ(metadata_utils::ValidationResult::kSegmentIDNotFound,
             metadata_utils::ValidateSegmentInfo(segment_info));
 
-  segment_info.set_segment_id(optimization_guide::proto::OptimizationTarget::
-                                  OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB);
+  segment_info.set_segment_id(
+      proto::SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB);
   EXPECT_EQ(metadata_utils::ValidationResult::kMetadataNotFound,
             metadata_utils::ValidateSegmentInfo(segment_info));
 
@@ -385,8 +385,8 @@
       metadata_utils::ValidationResult::kSegmentIDNotFound,
       metadata_utils::ValidateSegmentInfoMetadataAndFeatures(segment_info));
 
-  segment_info.set_segment_id(optimization_guide::proto::OptimizationTarget::
-                                  OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB);
+  segment_info.set_segment_id(
+      proto::SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB);
   EXPECT_EQ(
       metadata_utils::ValidationResult::kMetadataNotFound,
       metadata_utils::ValidateSegmentInfoMetadataAndFeatures(segment_info));
diff --git a/components/segmentation_platform/internal/metric_filter_utils.cc b/components/segmentation_platform/internal/metric_filter_utils.cc
index fbc015f8..b6c18495 100644
--- a/components/segmentation_platform/internal/metric_filter_utils.cc
+++ b/components/segmentation_platform/internal/metric_filter_utils.cc
@@ -9,12 +9,11 @@
 
 namespace segmentation_platform::stats {
 namespace {
-using optimization_guide::proto::OptimizationTarget;
+using proto::SegmentId;
 
 }  // namespace
 
-std::string OptimizationTargetToSegmentGroupName(
-    OptimizationTarget segment_id) {
+std::string OptimizationTargetToSegmentGroupName(SegmentId segment_id) {
   return OptimizationTargetToHistogramVariant(segment_id);
 }
 
@@ -25,7 +24,7 @@
 
 std::string SegmentationKeyToSubsegmentTrialName(
     const std::string& segmentation_key,
-    optimization_guide::proto::OptimizationTarget segment_id) {
+    proto::SegmentId segment_id) {
   return base::StrCat({"Segmentation_",
                        SegmentationKeyToUmaName(segmentation_key), "_",
                        OptimizationTargetToHistogramVariant(segment_id)});
diff --git a/components/segmentation_platform/internal/metric_filter_utils.h b/components/segmentation_platform/internal/metric_filter_utils.h
index 4ef5f80..f027cf6 100644
--- a/components/segmentation_platform/internal/metric_filter_utils.h
+++ b/components/segmentation_platform/internal/metric_filter_utils.h
@@ -9,15 +9,14 @@
 #include <string>
 #include <vector>
 
-#include "components/optimization_guide/proto/models.pb.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 #include "third_party/abseil-cpp/absl/types/optional.h"
 
 namespace segmentation_platform::stats {
 
 // Returns a name to be used in UMA dashboard as segment group for the given
 // `segment_id`.
-std::string OptimizationTargetToSegmentGroupName(
-    optimization_guide::proto::OptimizationTarget segment_id);
+std::string OptimizationTargetToSegmentGroupName(proto::SegmentId segment_id);
 
 // Returns a name to be used in UMA dashboard as segmentation type for the given
 // `segmentation_key`.
@@ -27,7 +26,7 @@
 // the given `segmentation_key` and `segment_id`.
 std::string SegmentationKeyToSubsegmentTrialName(
     const std::string& segmentation_key,
-    optimization_guide::proto::OptimizationTarget segment_id);
+    proto::SegmentId segment_id);
 
 }  // namespace segmentation_platform::stats
 
diff --git a/components/segmentation_platform/internal/proto/BUILD.gn b/components/segmentation_platform/internal/proto/BUILD.gn
index 09c721d1..2308e6cc 100644
--- a/components/segmentation_platform/internal/proto/BUILD.gn
+++ b/components/segmentation_platform/internal/proto/BUILD.gn
@@ -15,6 +15,5 @@
     "types.proto",
   ]
 
-  link_deps =
-      [ "//components/optimization_guide/proto:optimization_guide_proto" ]
+  link_deps = [ "//components/segmentation_platform/public/proto" ]
 }
diff --git a/components/segmentation_platform/internal/proto/model_prediction.proto b/components/segmentation_platform/internal/proto/model_prediction.proto
index c6ab5785..a8375a1 100644
--- a/components/segmentation_platform/internal/proto/model_prediction.proto
+++ b/components/segmentation_platform/internal/proto/model_prediction.proto
@@ -8,7 +8,7 @@
 package segmentation_platform.proto;
 
 import "components/segmentation_platform/internal/proto/model_metadata.proto";
-import "components/optimization_guide/proto/models.proto";
+import "components/segmentation_platform/public/proto/segmentation_platform.proto";
 
 // Result from the model evaluation for a given segment.
 message PredictionResult {
@@ -25,7 +25,7 @@
 // Next tag: 6
 message SegmentInfo {
   // Segment target.
-  optional optimization_guide.proto.OptimizationTarget segment_id = 1;
+  optional SegmentId segment_id = 1;
 
   // Cached copy of the segment metadata which is important in case the metadata
   // is temporarily not available in the future. It also contains the relevant
diff --git a/components/segmentation_platform/internal/scheduler/execution_service.cc b/components/segmentation_platform/internal/scheduler/execution_service.cc
index 6955be3..26b092c 100644
--- a/components/segmentation_platform/internal/scheduler/execution_service.cc
+++ b/components/segmentation_platform/internal/scheduler/execution_service.cc
@@ -40,7 +40,7 @@
     base::Clock* clock,
     ModelExecutionManager::SegmentationModelUpdatedCallback callback,
     scoped_refptr<base::SequencedTaskRunner> task_runner,
-    const base::flat_set<OptimizationTarget>& all_segment_ids,
+    const base::flat_set<SegmentId>& all_segment_ids,
     ModelProviderFactory* model_provider_factory,
     std::vector<ModelExecutionScheduler::Observer*>&& observers,
     const PlatformOptions& platform_options,
@@ -77,8 +77,7 @@
   model_execution_scheduler_->OnNewModelInfoReady(segment_info);
 }
 
-ModelProvider* ExecutionService::GetModelProvider(
-    OptimizationTarget segment_id) {
+ModelProvider* ExecutionService::GetModelProvider(SegmentId segment_id) {
   return model_execution_manager_->GetProvider(segment_id);
 }
 
@@ -105,7 +104,7 @@
 }
 
 void ExecutionService::OverwriteModelExecutionResult(
-    optimization_guide::proto::OptimizationTarget segment_id,
+    proto::SegmentId segment_id,
     const std::pair<float, ModelExecutionStatus>& result) {
   model_execution_scheduler_->OnModelExecutionCompleted(segment_id, result);
 }
diff --git a/components/segmentation_platform/internal/scheduler/execution_service.h b/components/segmentation_platform/internal/scheduler/execution_service.h
index ce2a8c38..58ec883 100644
--- a/components/segmentation_platform/internal/scheduler/execution_service.h
+++ b/components/segmentation_platform/internal/scheduler/execution_service.h
@@ -11,11 +11,11 @@
 #include "base/containers/flat_set.h"
 #include "base/task/sequenced_task_runner.h"
 #include "base/time/clock.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/execution/execution_request.h"
 #include "components/segmentation_platform/internal/execution/model_execution_manager_impl.h"
 #include "components/segmentation_platform/internal/input_context.h"
 #include "components/segmentation_platform/internal/scheduler/model_execution_scheduler.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 
 class PrefService;
 
@@ -53,7 +53,7 @@
       base::Clock* clock,
       ModelExecutionManager::SegmentationModelUpdatedCallback callback,
       scoped_refptr<base::SequencedTaskRunner> task_runner,
-      const base::flat_set<OptimizationTarget>& all_segment_ids,
+      const base::flat_set<SegmentId>& all_segment_ids,
       ModelProviderFactory* model_provider_factory,
       std::vector<ModelExecutionScheduler::Observer*>&& observers,
       const PlatformOptions& platform_options,
@@ -65,12 +65,12 @@
   void OnNewModelInfoReady(const proto::SegmentInfo& segment_info);
 
   // Gets the model provider for execution.
-  ModelProvider* GetModelProvider(OptimizationTarget segment_id);
+  ModelProvider* GetModelProvider(SegmentId segment_id);
 
   void RequestModelExecution(std::unique_ptr<ExecutionRequest> request);
 
   void OverwriteModelExecutionResult(
-      optimization_guide::proto::OptimizationTarget segment_id,
+      proto::SegmentId segment_id,
       const std::pair<float, ModelExecutionStatus>& result);
 
   // Refreshes model results for all eligible models.
diff --git a/components/segmentation_platform/internal/scheduler/model_execution_scheduler.h b/components/segmentation_platform/internal/scheduler/model_execution_scheduler.h
index 641b9806..a0ccfca3 100644
--- a/components/segmentation_platform/internal/scheduler/model_execution_scheduler.h
+++ b/components/segmentation_platform/internal/scheduler/model_execution_scheduler.h
@@ -5,16 +5,16 @@
 #ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SCHEDULER_MODEL_EXECUTION_SCHEDULER_H_
 #define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SCHEDULER_MODEL_EXECUTION_SCHEDULER_H_
 
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/execution/model_execution_status.h"
-
-using optimization_guide::proto::OptimizationTarget;
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 
 namespace segmentation_platform {
 namespace proto {
 class SegmentInfo;
 }  // namespace proto
 
+using proto::SegmentId;
+
 // Central class responsible for scheduling model execution. Determines which
 // models are eligible for execution based on various criteria e.g. cached
 // results, TTL etc. Invoked from multiple classes such as segment
@@ -25,7 +25,7 @@
   class Observer {
    public:
     // Called whenever a model execution completes.
-    virtual void OnModelExecutionCompleted(OptimizationTarget segment_id) = 0;
+    virtual void OnModelExecutionCompleted(SegmentId segment_id) = 0;
   };
 
   virtual ~ModelExecutionScheduler() = default;
@@ -54,7 +54,7 @@
   // TODO(shaktisahu): Do we want to store that failure reason in the DB
   // instead? We might treat different failures differently next time.
   virtual void OnModelExecutionCompleted(
-      OptimizationTarget segment_id,
+      SegmentId segment_id,
       const std::pair<float, ModelExecutionStatus>& result) = 0;
 };
 
diff --git a/components/segmentation_platform/internal/scheduler/model_execution_scheduler_impl.cc b/components/segmentation_platform/internal/scheduler/model_execution_scheduler_impl.cc
index 99b8c94..790cb79 100644
--- a/components/segmentation_platform/internal/scheduler/model_execution_scheduler_impl.cc
+++ b/components/segmentation_platform/internal/scheduler/model_execution_scheduler_impl.cc
@@ -26,7 +26,7 @@
     SignalStorageConfig* signal_storage_config,
     ModelExecutionManager* model_execution_manager,
     ModelExecutor* model_executor,
-    base::flat_set<optimization_guide::proto::OptimizationTarget> segment_ids,
+    base::flat_set<proto::SegmentId> segment_ids,
     base::Clock* clock,
     const PlatformOptions& platform_options)
     : observers_(observers),
@@ -59,8 +59,8 @@
 
 void ModelExecutionSchedulerImpl::RequestModelExecutionForEligibleSegments(
     bool expired_only) {
-  std::vector<OptimizationTarget> segment_ids(all_segment_ids_.begin(),
-                                              all_segment_ids_.end());
+  std::vector<SegmentId> segment_ids(all_segment_ids_.begin(),
+                                     all_segment_ids_.end());
   segment_database_->GetSegmentInfoForSegments(
       segment_ids,
       base::BindOnce(&ModelExecutionSchedulerImpl::FilterEligibleSegments,
@@ -69,7 +69,7 @@
 
 void ModelExecutionSchedulerImpl::RequestModelExecution(
     const proto::SegmentInfo& segment_info) {
-  OptimizationTarget segment_id = segment_info.segment_id();
+  SegmentId segment_id = segment_info.segment_id();
   CancelOutstandingExecutionRequests(segment_id);
   outstanding_requests_.insert(std::make_pair(
       segment_id,
@@ -86,7 +86,7 @@
 }
 
 void ModelExecutionSchedulerImpl::OnModelExecutionCompleted(
-    OptimizationTarget segment_id,
+    SegmentId segment_id,
     const std::pair<float, ModelExecutionStatus>& result) {
   // TODO(shaktisahu): Check ModelExecutionStatus and handle failure cases.
   // Should we save it to DB?
@@ -110,11 +110,11 @@
     std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> all_segments) {
   std::vector<const proto::SegmentInfo*> models_to_run;
   for (const auto& pair : *all_segments) {
-    OptimizationTarget segment_id = pair.first;
+    SegmentId segment_id = pair.first;
     const proto::SegmentInfo& segment_info = pair.second;
     if (!ShouldExecuteSegment(expired_only, segment_info)) {
       VLOG(1) << "Segmentation scheduler: Skipped executed segment "
-              << optimization_guide::proto::OptimizationTarget_Name(segment_id);
+              << proto::SegmentId_Name(segment_id);
       continue;
     }
 
@@ -168,7 +168,7 @@
 }
 
 void ModelExecutionSchedulerImpl::CancelOutstandingExecutionRequests(
-    OptimizationTarget segment_id) {
+    SegmentId segment_id) {
   const auto& iter = outstanding_requests_.find(segment_id);
   if (iter != outstanding_requests_.end()) {
     iter->second.Cancel();
@@ -176,7 +176,7 @@
   }
 }
 
-void ModelExecutionSchedulerImpl::OnResultSaved(OptimizationTarget segment_id,
+void ModelExecutionSchedulerImpl::OnResultSaved(SegmentId segment_id,
                                                 bool success) {
   stats::RecordModelExecutionSaveResult(segment_id, success);
   if (!success) {
diff --git a/components/segmentation_platform/internal/scheduler/model_execution_scheduler_impl.h b/components/segmentation_platform/internal/scheduler/model_execution_scheduler_impl.h
index 1b8954e..cbe2b1a 100644
--- a/components/segmentation_platform/internal/scheduler/model_execution_scheduler_impl.h
+++ b/components/segmentation_platform/internal/scheduler/model_execution_scheduler_impl.h
@@ -11,11 +11,11 @@
 #include "base/cancelable_callback.h"
 #include "base/containers/flat_set.h"
 #include "base/memory/weak_ptr.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/database/segment_info_database.h"
 #include "components/segmentation_platform/internal/execution/model_execution_status.h"
 #include "components/segmentation_platform/internal/execution/model_executor.h"
 #include "components/segmentation_platform/internal/platform_options.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 
 namespace base {
 class Clock;
@@ -32,15 +32,14 @@
 
 class ModelExecutionSchedulerImpl : public ModelExecutionScheduler {
  public:
-  ModelExecutionSchedulerImpl(
-      std::vector<Observer*>&& observers,
-      SegmentInfoDatabase* segment_database,
-      SignalStorageConfig* signal_storage_config,
-      ModelExecutionManager* model_execution_manager,
-      ModelExecutor* model_executor,
-      base::flat_set<optimization_guide::proto::OptimizationTarget> segment_ids,
-      base::Clock* clock,
-      const PlatformOptions& platform_options);
+  ModelExecutionSchedulerImpl(std::vector<Observer*>&& observers,
+                              SegmentInfoDatabase* segment_database,
+                              SignalStorageConfig* signal_storage_config,
+                              ModelExecutionManager* model_execution_manager,
+                              ModelExecutor* model_executor,
+                              base::flat_set<proto::SegmentId> segment_ids,
+                              base::Clock* clock,
+                              const PlatformOptions& platform_options);
   ~ModelExecutionSchedulerImpl() override;
 
   // Disallow copy/assign.
@@ -53,7 +52,7 @@
   void RequestModelExecutionForEligibleSegments(bool expired_only) override;
   void RequestModelExecution(const proto::SegmentInfo& segment_info) override;
   void OnModelExecutionCompleted(
-      OptimizationTarget segment_id,
+      SegmentId segment_id,
       const std::pair<float, ModelExecutionStatus>& score) override;
 
  private:
@@ -62,9 +61,9 @@
       std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> all_segments);
   bool ShouldExecuteSegment(bool expired_only,
                             const proto::SegmentInfo& segment_info);
-  void CancelOutstandingExecutionRequests(OptimizationTarget segment_id);
+  void CancelOutstandingExecutionRequests(SegmentId segment_id);
 
-  void OnResultSaved(OptimizationTarget segment_id, bool success);
+  void OnResultSaved(SegmentId segment_id, bool success);
 
   // Observers listening to model exeuction events. Required by the segment
   // selection pipeline.
@@ -81,8 +80,7 @@
   const raw_ptr<ModelExecutor> model_executor_;
 
   // The set of all known segments.
-  base::flat_set<optimization_guide::proto::OptimizationTarget>
-      all_segment_ids_;
+  base::flat_set<proto::SegmentId> all_segment_ids_;
 
   // The time provider.
   raw_ptr<base::Clock> clock_;
@@ -91,7 +89,7 @@
 
   // In-flight model execution requests. Will be killed if we get a model
   // update.
-  std::map<OptimizationTarget,
+  std::map<SegmentId,
            base::CancelableOnceCallback<
                ModelExecutor::ModelExecutionCallback::RunType>>
       outstanding_requests_;
diff --git a/components/segmentation_platform/internal/scheduler/model_execution_scheduler_unittest.cc b/components/segmentation_platform/internal/scheduler/model_execution_scheduler_unittest.cc
index a5560aa..d9cd92fe 100644
--- a/components/segmentation_platform/internal/scheduler/model_execution_scheduler_unittest.cc
+++ b/components/segmentation_platform/internal/scheduler/model_execution_scheduler_unittest.cc
@@ -30,23 +30,21 @@
 using CleanupItem = std::tuple<uint64_t, SignalType, base::Time>;
 
 namespace {
-constexpr auto kTestOptimizationTarget =
-    OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
+constexpr auto kTestSegmentId =
+    SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
 constexpr auto kTestOptimizationTarget2 =
-    OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_DUMMY;
+    SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_DUMMY;
 }  // namespace
 
 class MockModelExecutionObserver : public ModelExecutionScheduler::Observer {
  public:
   MockModelExecutionObserver() = default;
-  MOCK_METHOD(void, OnModelExecutionCompleted, (OptimizationTarget));
+  MOCK_METHOD(void, OnModelExecutionCompleted, (SegmentId));
 };
 
 class MockModelExecutionManager : public ModelExecutionManager {
  public:
-  MOCK_METHOD(ModelProvider*,
-              GetProvider,
-              (optimization_guide::proto::OptimizationTarget segment_id));
+  MOCK_METHOD(ModelProvider*, GetProvider, (proto::SegmentId segment_id));
 };
 
 class MockModelExecutor : public ModelExecutor {
@@ -65,8 +63,8 @@
     std::vector<ModelExecutionScheduler::Observer*> observers = {&observer1_,
                                                                  &observer2_};
     segment_database_ = std::make_unique<test::TestSegmentInfoDatabase>();
-    base::flat_set<OptimizationTarget> segment_ids;
-    segment_ids.insert(kTestOptimizationTarget);
+    base::flat_set<SegmentId> segment_ids;
+    segment_ids.insert(kTestSegmentId);
     model_execution_scheduler_ = std::make_unique<ModelExecutionSchedulerImpl>(
         std::move(observers), segment_database_.get(), &signal_storage_config_,
         &model_execution_manager_, &model_executor_, segment_ids, &clock_,
@@ -89,13 +87,12 @@
 }
 
 TEST_F(ModelExecutionSchedulerTest, OnNewModelInfoReady) {
-  auto* segment_info =
-      segment_database_->FindOrCreateSegment(kTestOptimizationTarget);
-  segment_info->set_segment_id(kTestOptimizationTarget);
+  auto* segment_info = segment_database_->FindOrCreateSegment(kTestSegmentId);
+  segment_info->set_segment_id(kTestSegmentId);
   auto* metadata = segment_info->mutable_model_metadata();
   metadata->set_result_time_to_live(1);
   metadata->set_time_unit(proto::TimeUnit::DAY);
-  MockModelProvider provider(kTestOptimizationTarget, base::DoNothing());
+  MockModelProvider provider(kTestSegmentId, base::DoNothing());
 
   // If the metadata DOES NOT meet the signal requirement, we SHOULD NOT try to
   // execute the model.
@@ -106,10 +103,9 @@
 
   // If the metadata DOES meet the signal requirement, and we have no old,
   // PredictionResult we SHOULD try to execute the model.
-  EXPECT_CALL(model_execution_manager_, GetProvider(kTestOptimizationTarget))
+  EXPECT_CALL(model_execution_manager_, GetProvider(kTestSegmentId))
       .WillOnce(Return(&provider));
-  EXPECT_CALL(model_executor_,
-              ExecuteModel(IsForTarget(kTestOptimizationTarget)))
+  EXPECT_CALL(model_executor_, ExecuteModel(IsForTarget(kTestSegmentId)))
       .Times(1);
   EXPECT_CALL(signal_storage_config_, MeetsSignalCollectionRequirement(_, _))
       .WillOnce(Return(true));
@@ -141,26 +137,24 @@
   prediction_result->set_result(0.9);
   prediction_result->set_timestamp_us(
       just_expired_timestamp.ToDeltaSinceWindowsEpoch().InMicroseconds());
-  EXPECT_CALL(model_execution_manager_, GetProvider(kTestOptimizationTarget))
+  EXPECT_CALL(model_execution_manager_, GetProvider(kTestSegmentId))
       .WillOnce(Return(&provider));
-  EXPECT_CALL(model_executor_,
-              ExecuteModel(IsForTarget(kTestOptimizationTarget)))
+  EXPECT_CALL(model_executor_, ExecuteModel(IsForTarget(kTestSegmentId)))
       .Times(1);
   model_execution_scheduler_->OnNewModelInfoReady(*segment_info);
 }
 
 TEST_F(ModelExecutionSchedulerTest, RequestModelExecutionForEligibleSegments) {
-  MockModelProvider provider(kTestOptimizationTarget, base::DoNothing());
-  segment_database_->FindOrCreateSegment(kTestOptimizationTarget);
+  MockModelProvider provider(kTestSegmentId, base::DoNothing());
+  segment_database_->FindOrCreateSegment(kTestSegmentId);
   segment_database_->FindOrCreateSegment(kTestOptimizationTarget2);
 
   // TODO(shaktisahu): Add tests for expired segments, freshly computed segments
   // etc.
 
-  EXPECT_CALL(model_execution_manager_, GetProvider(kTestOptimizationTarget))
+  EXPECT_CALL(model_execution_manager_, GetProvider(kTestSegmentId))
       .WillOnce(Return(&provider));
-  EXPECT_CALL(model_executor_,
-              ExecuteModel(IsForTarget(kTestOptimizationTarget)))
+  EXPECT_CALL(model_executor_, ExecuteModel(IsForTarget(kTestSegmentId)))
       .Times(1);
   EXPECT_CALL(signal_storage_config_, MeetsSignalCollectionRequirement(_, _))
       .WillRepeatedly(Return(true));
@@ -174,21 +168,17 @@
 
 TEST_F(ModelExecutionSchedulerTest, OnModelExecutionCompleted) {
   proto::SegmentInfo* segment_info =
-      segment_database_->FindOrCreateSegment(kTestOptimizationTarget);
+      segment_database_->FindOrCreateSegment(kTestSegmentId);
 
   // TODO(shaktisahu): Add tests for model failure.
-  EXPECT_CALL(observer2_, OnModelExecutionCompleted(kTestOptimizationTarget))
-      .Times(1);
-  EXPECT_CALL(observer1_, OnModelExecutionCompleted(kTestOptimizationTarget))
-      .Times(1);
+  EXPECT_CALL(observer2_, OnModelExecutionCompleted(kTestSegmentId)).Times(1);
+  EXPECT_CALL(observer1_, OnModelExecutionCompleted(kTestSegmentId)).Times(1);
   float score = 0.4;
   model_execution_scheduler_->OnModelExecutionCompleted(
-      kTestOptimizationTarget,
-      std::make_pair(score, ModelExecutionStatus::kSuccess));
+      kTestSegmentId, std::make_pair(score, ModelExecutionStatus::kSuccess));
 
   // Verify that the results are written to the DB.
-  segment_info =
-      segment_database_->FindOrCreateSegment(kTestOptimizationTarget);
+  segment_info = segment_database_->FindOrCreateSegment(kTestSegmentId);
   ASSERT_TRUE(segment_info->has_prediction_result());
   ASSERT_EQ(score, segment_info->prediction_result().result());
 }
diff --git a/components/segmentation_platform/internal/segment_id_convertor.cc b/components/segmentation_platform/internal/segment_id_convertor.cc
new file mode 100644
index 0000000..a60bb91
--- /dev/null
+++ b/components/segmentation_platform/internal/segment_id_convertor.cc
@@ -0,0 +1,19 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "components/segmentation_platform/internal/segment_id_convertor.h"
+
+namespace segmentation_platform {
+
+optimization_guide::proto::OptimizationTarget SegmentIdToOptimizationTarget(
+    proto::SegmentId segment_id) {
+  return static_cast<optimization_guide::proto::OptimizationTarget>(segment_id);
+}
+
+proto::SegmentId OptimizationTargetToSegmentId(
+    optimization_guide::proto::OptimizationTarget segment_id) {
+  return static_cast<proto::SegmentId>(segment_id);
+}
+
+}  // namespace segmentation_platform
diff --git a/components/segmentation_platform/internal/segment_id_convertor.h b/components/segmentation_platform/internal/segment_id_convertor.h
new file mode 100644
index 0000000..e0e70e6
--- /dev/null
+++ b/components/segmentation_platform/internal/segment_id_convertor.h
@@ -0,0 +1,23 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SEGMENT_ID_CONVERTOR_H_
+#define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SEGMENT_ID_CONVERTOR_H_
+
+#include "components/optimization_guide/proto/models.pb.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
+
+namespace segmentation_platform {
+
+// Converts a SegmentId to the OptimizationTarget with the same numeric value.
+optimization_guide::proto::OptimizationTarget SegmentIdToOptimizationTarget(
+    proto::SegmentId segment_id);
+
+// Converts an OptimizationTarget to the SegmentId with the same numeric value.
+proto::SegmentId OptimizationTargetToSegmentId(
+    optimization_guide::proto::OptimizationTarget segment_id);
+
+}  // namespace segmentation_platform
+
+#endif  // COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SEGMENT_ID_CONVERTOR_H_
diff --git a/components/segmentation_platform/internal/segmentation_platform_service_impl.cc b/components/segmentation_platform/internal/segmentation_platform_service_impl.cc
index 44ceb10..09022c0 100644
--- a/components/segmentation_platform/internal/segmentation_platform_service_impl.cc
+++ b/components/segmentation_platform/internal/segmentation_platform_service_impl.cc
@@ -30,15 +30,14 @@
 #include "components/segmentation_platform/public/field_trial_register.h"
 #include "components/segmentation_platform/public/model_provider.h"
 
-using optimization_guide::proto::OptimizationTarget;
-
 namespace segmentation_platform {
-
 namespace {
 
-base::flat_set<OptimizationTarget> GetAllSegmentIds(
+using proto::SegmentId;
+
+base::flat_set<SegmentId> GetAllSegmentIds(
     const std::vector<std::unique_ptr<Config>>& configs) {
-  base::flat_set<OptimizationTarget> all_segment_ids;
+  base::flat_set<SegmentId> all_segment_ids;
   for (const auto& config : configs) {
     for (const auto& segment_id : config->segment_ids)
       all_segment_ids.insert(segment_id);
@@ -83,8 +82,8 @@
         model_provider_factory_.get());
   }
 
-  std::vector<OptimizationTarget> segment_id_vec(all_segment_ids_.begin(),
-                                                 all_segment_ids_.end());
+  std::vector<SegmentId> segment_id_vec(all_segment_ids_.begin(),
+                                        all_segment_ids_.end());
 
   // Construct signal processors.
   signal_handler_.Initialize(
diff --git a/components/segmentation_platform/internal/segmentation_platform_service_impl.h b/components/segmentation_platform/internal/segmentation_platform_service_impl.h
index 84d487c5..b7aa877 100644
--- a/components/segmentation_platform/internal/segmentation_platform_service_impl.h
+++ b/components/segmentation_platform/internal/segmentation_platform_service_impl.h
@@ -14,13 +14,13 @@
 #include "base/memory/raw_ptr.h"
 #include "base/memory/scoped_refptr.h"
 #include "base/memory/weak_ptr.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/database/storage_service.h"
 #include "components/segmentation_platform/internal/execution/model_execution_manager.h"
 #include "components/segmentation_platform/internal/platform_options.h"
 #include "components/segmentation_platform/internal/scheduler/execution_service.h"
 #include "components/segmentation_platform/internal/service_proxy_impl.h"
 #include "components/segmentation_platform/internal/signals/signal_handler.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 #include "components/segmentation_platform/public/segmentation_platform_service.h"
 
 namespace base {
@@ -123,8 +123,7 @@
 
   // Config.
   std::vector<std::unique_ptr<Config>> configs_;
-  base::flat_set<optimization_guide::proto::OptimizationTarget>
-      all_segment_ids_;
+  base::flat_set<proto::SegmentId> all_segment_ids_;
   std::unique_ptr<FieldTrialRegister> field_trial_register_;
 
   std::unique_ptr<StorageService> storage_service_;
diff --git a/components/segmentation_platform/internal/segmentation_platform_service_impl_unittest.cc b/components/segmentation_platform/internal/segmentation_platform_service_impl_unittest.cc
index e9451cf..b8b733f 100644
--- a/components/segmentation_platform/internal/segmentation_platform_service_impl_unittest.cc
+++ b/components/segmentation_platform/internal/segmentation_platform_service_impl_unittest.cc
@@ -19,6 +19,7 @@
 #include "components/segmentation_platform/internal/database/mock_ukm_database.h"
 #include "components/segmentation_platform/internal/dummy_ukm_data_manager.h"
 #include "components/segmentation_platform/internal/proto/model_metadata.pb.h"
+#include "components/segmentation_platform/internal/segment_id_convertor.h"
 #include "components/segmentation_platform/internal/segmentation_platform_service_test_base.h"
 #include "components/segmentation_platform/internal/selection/segmentation_result_prefs.h"
 #include "components/segmentation_platform/internal/signals/ukm_observer.h"
@@ -101,12 +102,11 @@
   void AssertSelectedSegment(
       const std::string& segmentation_key,
       bool is_ready,
-      OptimizationTarget expected =
-          OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN) {
+      SegmentId expected = SegmentId::OPTIMIZATION_TARGET_UNKNOWN) {
     SegmentSelectionResult result;
     result.is_ready = is_ready;
     if (is_ready)
-      result.segment = expected;
+      result.segment = SegmentIdToOptimizationTarget(expected);
     base::RunLoop loop;
     segmentation_platform_service_impl_->GetSelectedSegment(
         segmentation_key,
@@ -119,12 +119,11 @@
   void AssertCachedSegment(
       const std::string& segmentation_key,
       bool is_ready,
-      OptimizationTarget expected =
-          OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN) {
+      SegmentId expected = SegmentId::OPTIMIZATION_TARGET_UNKNOWN) {
     SegmentSelectionResult result;
     result.is_ready = is_ready;
     if (is_ready)
-      result.segment = expected;
+      result.segment = SegmentIdToOptimizationTarget(expected);
     ASSERT_EQ(result,
               segmentation_platform_service_impl_->GetCachedSegmentResult(
                   segmentation_key));
@@ -166,12 +165,12 @@
     // from the database, and then write the merged result of the old and new to
     // the database.
     ASSERT_TRUE(model_provider_data_.model_providers_callbacks.count(
-        OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE));
+        SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE));
     model_provider_data_
         .model_providers_callbacks
-            [OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE]
-        .Run(OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE,
-             metadata, kModelVersion);
+            [SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE]
+        .Run(SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE, metadata,
+             kModelVersion);
     segment_db_->GetCallback(true);
     segment_db_->UpdateCallback(true);
 
@@ -185,14 +184,12 @@
         histogram_tester.GetBucketCount(
             "SegmentationPlatform.Signals.ListeningCount.HistogramValue", 1));
 
-    AssertSelectedSegment(
-        kTestSegmentationKey1, true,
-        OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
+    AssertSelectedSegment(kTestSegmentationKey1, true,
+                          SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
     AssertSelectedSegment(kTestSegmentationKey2, false);
     AssertSelectedSegment(kTestSegmentationKey3, false);
-    AssertCachedSegment(
-        kTestSegmentationKey1, true,
-        OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
+    AssertCachedSegment(kTestSegmentationKey1, true,
+                        SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
     AssertCachedSegment(kTestSegmentationKey2, false);
     AssertCachedSegment(kTestSegmentationKey3, false);
 
@@ -202,12 +199,12 @@
     segment_db_->LoadCallback(true);
 
     ASSERT_TRUE(model_provider_data_.model_providers_callbacks.count(
-        OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_VOICE));
+        SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_VOICE));
     model_provider_data_
         .model_providers_callbacks
-            [OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_VOICE]
-        .Run(OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_VOICE,
-             metadata, kModelVersion);
+            [SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_VOICE]
+        .Run(SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_VOICE, metadata,
+             kModelVersion);
     segment_db_->GetCallback(true);
     segment_db_->UpdateCallback(true);
 
@@ -232,14 +229,12 @@
     task_environment_.FastForwardBy(base::Hours(1));
     segment_db_->LoadCallback(true);
 
-    AssertSelectedSegment(
-        kTestSegmentationKey1, true,
-        OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
+    AssertSelectedSegment(kTestSegmentationKey1, true,
+                          SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
     AssertSelectedSegment(kTestSegmentationKey2, false);
     AssertSelectedSegment(kTestSegmentationKey3, false);
-    AssertCachedSegment(
-        kTestSegmentationKey1, true,
-        OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
+    AssertCachedSegment(kTestSegmentationKey1, true,
+                        SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
     AssertCachedSegment(kTestSegmentationKey2, false);
     AssertCachedSegment(kTestSegmentationKey3, false);
   }
@@ -299,14 +294,12 @@
 
     base::Value segmentation_result(base::Value::Type::DICTIONARY);
     segmentation_result.SetIntKey(
-        "segment_id",
-        OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
+        "segment_id", SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
     dictionary->SetKey(kTestSegmentationKey1, std::move(segmentation_result));
 
     base::Value segmentation_result2(base::Value::Type::DICTIONARY);
     segmentation_result2.SetIntKey(
-        "segment_id",
-        OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_VOICE);
+        "segment_id", SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_VOICE);
     dictionary->SetKey(kTestSegmentationKey2, std::move(segmentation_result2));
   }
 };
@@ -324,19 +317,15 @@
   // querying segment db.
   segment_db_->LoadCallback(true);
 
-  AssertSelectedSegment(
-      kTestSegmentationKey1, true,
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
-  AssertSelectedSegment(
-      kTestSegmentationKey2, true,
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_VOICE);
+  AssertSelectedSegment(kTestSegmentationKey1, true,
+                        SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
+  AssertSelectedSegment(kTestSegmentationKey2, true,
+                        SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_VOICE);
   AssertSelectedSegment(kTestSegmentationKey3, false);
-  AssertCachedSegment(
-      kTestSegmentationKey1, true,
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
-  AssertCachedSegment(
-      kTestSegmentationKey2, true,
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_VOICE);
+  AssertCachedSegment(kTestSegmentationKey1, true,
+                      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
+  AssertCachedSegment(kTestSegmentationKey2, true,
+                      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_VOICE);
   AssertCachedSegment(kTestSegmentationKey3, false);
 }
 
diff --git a/components/segmentation_platform/internal/segmentation_platform_service_test_base.cc b/components/segmentation_platform/internal/segmentation_platform_service_test_base.cc
index 18ea67b..8fa92fb 100644
--- a/components/segmentation_platform/internal/segmentation_platform_service_test_base.cc
+++ b/components/segmentation_platform/internal/segmentation_platform_service_test_base.cc
@@ -27,7 +27,7 @@
                     base::StringPiece group_name));
   MOCK_METHOD3(RegisterSubsegmentFieldTrialIfNeeded,
                void(base::StringPiece trial_name,
-                    optimization_guide::proto::OptimizationTarget segment_id,
+                    proto::SegmentId segment_id,
                     int subsegment_rank));
 };
 
@@ -37,26 +37,23 @@
     std::unique_ptr<Config> config = std::make_unique<Config>();
     config->segmentation_key = kTestSegmentationKey1;
     config->segment_selection_ttl = base::Days(28);
-    config->segment_ids = {
-        OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
-        OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE};
+    config->segment_ids = {SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
+                           SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE};
     configs.push_back(std::move(config));
   }
   {
     std::unique_ptr<Config> config = std::make_unique<Config>();
     config->segmentation_key = kTestSegmentationKey2;
     config->segment_selection_ttl = base::Days(10);
-    config->segment_ids = {
-        OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE,
-        OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_VOICE};
+    config->segment_ids = {SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE,
+                           SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_VOICE};
     configs.push_back(std::move(config));
   }
   {
     std::unique_ptr<Config> config = std::make_unique<Config>();
     config->segmentation_key = kTestSegmentationKey3;
     config->segment_selection_ttl = base::Days(14);
-    config->segment_ids = {
-        OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB};
+    config->segment_ids = {SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB};
     configs.push_back(std::move(config));
   }
   {
@@ -104,7 +101,7 @@
   SetUpPrefs();
 
   std::vector<std::unique_ptr<Config>> configs = CreateTestConfigs();
-  base::flat_set<OptimizationTarget> all_segment_ids;
+  base::flat_set<SegmentId> all_segment_ids;
   for (const auto& config : configs) {
     for (const auto& segment_id : config->segment_ids)
       all_segment_ids.insert(segment_id);
@@ -141,7 +138,7 @@
 
   base::Value segmentation_result(base::Value::Type::DICTIONARY);
   segmentation_result.SetIntKey(
-      "segment_id", OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
+      "segment_id", SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
   dictionary->SetKey(kTestSegmentationKey1, std::move(segmentation_result));
 }
 
diff --git a/components/segmentation_platform/internal/segmentation_ukm_helper.cc b/components/segmentation_platform/internal/segmentation_ukm_helper.cc
index 24bc080..38d0751 100644
--- a/components/segmentation_platform/internal/segmentation_ukm_helper.cc
+++ b/components/segmentation_platform/internal/segmentation_ukm_helper.cc
@@ -21,6 +21,7 @@
 #define CALL_MEMBER_FN(obj, func) ((obj).*(func))
 #define ARRAY_SIZE(x) (sizeof(x) / sizeof(x)[0])
 
+using segmentation_platform::proto::SegmentId;
 using ukm::builders::Segmentation_ModelExecution;
 
 namespace {
@@ -67,7 +68,7 @@
     &Segmentation_ModelExecution::SetActualResult5,
     &Segmentation_ModelExecution::SetActualResult6};
 
-base::flat_set<OptimizationTarget> GetSegmentIdsAllowedForReporting() {
+base::flat_set<SegmentId> GetSegmentIdsAllowedForReporting() {
   std::vector<std::string> segment_ids = base::SplitString(
       base::GetFieldTrialParamValueByFeature(
           segmentation_platform::features::
@@ -75,11 +76,11 @@
           segmentation_platform::kSegmentIdsAllowedForReportingKey),
       ",;", base::WhitespaceHandling::TRIM_WHITESPACE,
       base::SplitResult::SPLIT_WANT_NONEMPTY);
-  base::flat_set<OptimizationTarget> result;
+  base::flat_set<SegmentId> result;
   for (const auto& id : segment_ids) {
     int segment_id;
     if (base::StringToInt(id, &segment_id))
-      result.emplace(static_cast<OptimizationTarget>(segment_id));
+      result.emplace(static_cast<SegmentId>(segment_id));
   }
   return result;
 }
@@ -105,7 +106,7 @@
 }
 
 ukm::SourceId SegmentationUkmHelper::RecordModelExecutionResult(
-    OptimizationTarget segment_id,
+    SegmentId segment_id,
     int64_t model_version,
     const std::vector<float>& input_tensor,
     float result) {
@@ -124,7 +125,7 @@
 }
 
 ukm::SourceId SegmentationUkmHelper::RecordTrainingData(
-    OptimizationTarget segment_id,
+    SegmentId segment_id,
     int64_t model_version,
     const std::vector<float>& input_tensor,
     const std::vector<float>& outputs,
@@ -158,7 +159,7 @@
 
 bool SegmentationUkmHelper::AddInputsToUkm(
     ukm::builders::Segmentation_ModelExecution* ukm_builder,
-    OptimizationTarget segment_id,
+    SegmentId segment_id,
     int64_t model_version,
     const std::vector<float>& input_tensor) {
   if (!allowed_segment_ids_.contains(static_cast<int>(segment_id)))
diff --git a/components/segmentation_platform/internal/segmentation_ukm_helper.h b/components/segmentation_platform/internal/segmentation_ukm_helper.h
index 4ceb6e5..80b2021 100644
--- a/components/segmentation_platform/internal/segmentation_ukm_helper.h
+++ b/components/segmentation_platform/internal/segmentation_ukm_helper.h
@@ -8,13 +8,11 @@
 #include "base/containers/flat_set.h"
 #include "base/no_destructor.h"
 #include "base/time/time.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 #include "services/metrics/public/cpp/ukm_source_id.h"
 #include "third_party/abseil-cpp/absl/types/optional.h"
 
-using optimization_guide::proto::OptimizationTarget;
-
 namespace base {
 class Clock;
 }
@@ -24,6 +22,8 @@
 }  // namespace ukm::builders
 
 namespace segmentation_platform {
+
+using proto::SegmentId;
 struct SelectedSegment;
 
 // A helper class to record segmentation model execution results in UKM.
@@ -36,7 +36,7 @@
   // Record segmentation model information and input/output after the
   // executing the model, and return the UKM source ID.
   ukm::SourceId RecordModelExecutionResult(
-      OptimizationTarget segment_id,
+      SegmentId segment_id,
       int64_t model_version,
       const std::vector<float>& input_tensor,
       float result);
@@ -51,7 +51,7 @@
   // tied to the ML model.
   // Return the UKM source ID.
   ukm::SourceId RecordTrainingData(
-      OptimizationTarget segment_id,
+      SegmentId segment_id,
       int64_t model_version,
       const std::vector<float>& input_tensors,
       const std::vector<float>& outputs,
@@ -68,13 +68,13 @@
                                   base::Clock* clock);
 
   // Gets a set of segment IDs that are allowed to upload metrics.
-  const base::flat_set<OptimizationTarget>& allowed_segment_ids() {
+  const base::flat_set<SegmentId>& allowed_segment_ids() {
     return allowed_segment_ids_;
   }
 
  private:
   bool AddInputsToUkm(ukm::builders::Segmentation_ModelExecution* ukm_builder,
-                      OptimizationTarget segment_id,
+                      SegmentId segment_id,
                       int64_t model_version,
                       const std::vector<float>& input_tensor);
 
@@ -89,7 +89,7 @@
 
   void Initialize();
 
-  base::flat_set<OptimizationTarget> allowed_segment_ids_;
+  base::flat_set<SegmentId> allowed_segment_ids_;
 };
 
 }  // namespace segmentation_platform
diff --git a/components/segmentation_platform/internal/segmentation_ukm_helper_unittest.cc b/components/segmentation_platform/internal/segmentation_ukm_helper_unittest.cc
index 8a3d59ff..df758d0 100644
--- a/components/segmentation_platform/internal/segmentation_ukm_helper_unittest.cc
+++ b/components/segmentation_platform/internal/segmentation_ukm_helper_unittest.cc
@@ -24,6 +24,8 @@
 
 using Segmentation_ModelExecution = ukm::builders::Segmentation_ModelExecution;
 
+namespace segmentation_platform {
+
 namespace {
 
 // Round errors allowed during conversion.
@@ -42,17 +44,14 @@
       kRoundingError);
 }
 
-absl::optional<segmentation_platform::proto::PredictionResult>
-GetPredictionResult() {
-  segmentation_platform::proto::PredictionResult result;
+absl::optional<proto::PredictionResult> GetPredictionResult() {
+  proto::PredictionResult result;
   result.set_result(0.5);
   return result;
 }
 
 }  // namespace
 
-namespace segmentation_platform {
-
 class SegmentationUkmHelperTest : public testing::Test {
  public:
   SegmentationUkmHelperTest() = default;
@@ -109,26 +108,24 @@
   InitializeAllowedSegmentIds("4");
   std::vector<float> input_tensors = {0.1, 0.7, 0.8, 0.5};
   SegmentationUkmHelper::GetInstance()->RecordModelExecutionResult(
-      optimization_guide::proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 101,
-      input_tensors, 0.6);
-  ExpectUkmMetrics(
-      Segmentation_ModelExecution::kEntryName,
-      {Segmentation_ModelExecution::kOptimizationTargetName,
-       Segmentation_ModelExecution::kModelVersionName,
-       Segmentation_ModelExecution::kInput0Name,
-       Segmentation_ModelExecution::kInput1Name,
-       Segmentation_ModelExecution::kInput2Name,
-       Segmentation_ModelExecution::kInput3Name,
-       Segmentation_ModelExecution::kPredictionResultName},
-      {
-          optimization_guide::proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
-          101,
-          SegmentationUkmHelper::FloatToInt64(0.1),
-          SegmentationUkmHelper::FloatToInt64(0.7),
-          SegmentationUkmHelper::FloatToInt64(0.8),
-          SegmentationUkmHelper::FloatToInt64(0.5),
-          SegmentationUkmHelper::FloatToInt64(0.6),
-      });
+      proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 101, input_tensors, 0.6);
+  ExpectUkmMetrics(Segmentation_ModelExecution::kEntryName,
+                   {Segmentation_ModelExecution::kOptimizationTargetName,
+                    Segmentation_ModelExecution::kModelVersionName,
+                    Segmentation_ModelExecution::kInput0Name,
+                    Segmentation_ModelExecution::kInput1Name,
+                    Segmentation_ModelExecution::kInput2Name,
+                    Segmentation_ModelExecution::kInput3Name,
+                    Segmentation_ModelExecution::kPredictionResultName},
+                   {
+                       proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
+                       101,
+                       SegmentationUkmHelper::FloatToInt64(0.1),
+                       SegmentationUkmHelper::FloatToInt64(0.7),
+                       SegmentationUkmHelper::FloatToInt64(0.8),
+                       SegmentationUkmHelper::FloatToInt64(0.5),
+                       SegmentationUkmHelper::FloatToInt64(0.6),
+                   });
 }
 
 // Tests that the training data collection recording works properly.
@@ -140,32 +137,30 @@
   std::vector<int> output_indexes = {2, 3};
 
   SelectedSegment selected_segment(
-      optimization_guide::proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB);
+      proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB);
   selected_segment.selection_time = base::Time::Now() - base::Seconds(10);
   SegmentationUkmHelper::GetInstance()->RecordTrainingData(
-      optimization_guide::proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 101,
-      input_tensors, outputs, output_indexes, GetPredictionResult(),
-      selected_segment);
-  ExpectUkmMetrics(
-      Segmentation_ModelExecution::kEntryName,
-      {Segmentation_ModelExecution::kOptimizationTargetName,
-       Segmentation_ModelExecution::kModelVersionName,
-       Segmentation_ModelExecution::kInput0Name,
-       Segmentation_ModelExecution::kActualResult3Name,
-       Segmentation_ModelExecution::kActualResult4Name,
-       Segmentation_ModelExecution::kPredictionResultName,
-       Segmentation_ModelExecution::kSelectionResultName,
-       Segmentation_ModelExecution::kOutputDelaySecName},
-      {
-          optimization_guide::proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
-          101,
-          SegmentationUkmHelper::FloatToInt64(0.1),
-          SegmentationUkmHelper::FloatToInt64(1.0),
-          SegmentationUkmHelper::FloatToInt64(0.0),
-          SegmentationUkmHelper::FloatToInt64(0.5),
-          optimization_guide::proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
-          10,
-      });
+      proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 101, input_tensors,
+      outputs, output_indexes, GetPredictionResult(), selected_segment);
+  ExpectUkmMetrics(Segmentation_ModelExecution::kEntryName,
+                   {Segmentation_ModelExecution::kOptimizationTargetName,
+                    Segmentation_ModelExecution::kModelVersionName,
+                    Segmentation_ModelExecution::kInput0Name,
+                    Segmentation_ModelExecution::kActualResult3Name,
+                    Segmentation_ModelExecution::kActualResult4Name,
+                    Segmentation_ModelExecution::kPredictionResultName,
+                    Segmentation_ModelExecution::kSelectionResultName,
+                    Segmentation_ModelExecution::kOutputDelaySecName},
+                   {
+                       proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
+                       101,
+                       SegmentationUkmHelper::FloatToInt64(0.1),
+                       SegmentationUkmHelper::FloatToInt64(1.0),
+                       SegmentationUkmHelper::FloatToInt64(0.0),
+                       SegmentationUkmHelper::FloatToInt64(0.5),
+                       proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
+                       10,
+                   });
 }
 
 // Tests that recording is disabled if kSegmentationStructuredMetricsFeature
@@ -174,8 +169,7 @@
   DisableStructureMetrics();
   std::vector<float> input_tensors = {0.1, 0.7, 0.8, 0.5};
   SegmentationUkmHelper::GetInstance()->RecordModelExecutionResult(
-      optimization_guide::proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 101,
-      input_tensors, 0.6);
+      proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 101, input_tensors, 0.6);
   ExpectEmptyUkmMetrics(Segmentation_ModelExecution::kEntryName);
 }
 
@@ -185,8 +179,7 @@
   InitializeAllowedSegmentIds("7, 8");
   std::vector<float> input_tensors = {0.1, 0.7, 0.8, 0.5};
   SegmentationUkmHelper::GetInstance()->RecordModelExecutionResult(
-      optimization_guide::proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 101,
-      input_tensors, 0.6);
+      proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 101, input_tensors, 0.6);
   ExpectEmptyUkmMetrics(Segmentation_ModelExecution::kEntryName);
 }
 
@@ -221,8 +214,8 @@
   std::vector<float> input_tensors(100, 0.1);
   ukm::SourceId source_id =
       SegmentationUkmHelper::GetInstance()->RecordModelExecutionResult(
-          optimization_guide::proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
-          101, input_tensors, 0.6);
+          proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 101, input_tensors,
+          0.6);
   ASSERT_EQ(source_id, ukm::kInvalidSourceId);
   tester.ExpectTotalCount(histogram_name, 1);
   ASSERT_EQ(tester.GetTotalSum(histogram_name), 100);
@@ -239,25 +232,22 @@
 
   ukm::SourceId source_id =
       SegmentationUkmHelper::GetInstance()->RecordTrainingData(
-          optimization_guide::proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
-          101, input_tensors, outputs, output_indexes, GetPredictionResult(),
-          absl::nullopt);
+          proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 101, input_tensors,
+          outputs, output_indexes, GetPredictionResult(), absl::nullopt);
   ASSERT_EQ(source_id, ukm::kInvalidSourceId);
 
   // output_indexes value too large.
   output_indexes = {100, 1000};
   source_id = SegmentationUkmHelper::GetInstance()->RecordTrainingData(
-      optimization_guide::proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 101,
-      input_tensors, outputs, output_indexes, GetPredictionResult(),
-      absl::nullopt);
+      proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 101, input_tensors,
+      outputs, output_indexes, GetPredictionResult(), absl::nullopt);
   ASSERT_EQ(source_id, ukm::kInvalidSourceId);
 
   // Valid outputs.
   output_indexes = {3, 0};
   source_id = SegmentationUkmHelper::GetInstance()->RecordTrainingData(
-      optimization_guide::proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 101,
-      input_tensors, outputs, output_indexes, GetPredictionResult(),
-      absl::nullopt);
+      proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 101, input_tensors,
+      outputs, output_indexes, GetPredictionResult(), absl::nullopt);
   ASSERT_NE(source_id, ukm::kInvalidSourceId);
 }
 
diff --git a/components/segmentation_platform/internal/selection/experimental_group_recorder.cc b/components/segmentation_platform/internal/selection/experimental_group_recorder.cc
index 43137e5..09b226f 100644
--- a/components/segmentation_platform/internal/selection/experimental_group_recorder.cc
+++ b/components/segmentation_platform/internal/selection/experimental_group_recorder.cc
@@ -20,7 +20,7 @@
     SegmentInfoDatabase* segment_database,
     FieldTrialRegister* field_trial_register,
     const std::string& segmentation_key,
-    optimization_guide::proto::OptimizationTarget selected_segment)
+    proto::SegmentId selected_segment)
     : field_trial_register_(field_trial_register),
       segmentation_key_(segmentation_key) {
   segment_database->GetSegmentInfo(
diff --git a/components/segmentation_platform/internal/selection/experimental_group_recorder.h b/components/segmentation_platform/internal/selection/experimental_group_recorder.h
index bf08ae19..4bf0b70 100644
--- a/components/segmentation_platform/internal/selection/experimental_group_recorder.h
+++ b/components/segmentation_platform/internal/selection/experimental_group_recorder.h
@@ -7,8 +7,8 @@
 
 #include "base/memory/raw_ptr.h"
 #include "base/memory/weak_ptr.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 #include "third_party/abseil-cpp/absl/types/optional.h"
 
 namespace segmentation_platform {
@@ -22,11 +22,10 @@
   // On construction, gets the model score from the database and records the
   // subsegment based on the score. This class must be kept alive till the
   // recording is complete, can be used only once.
-  ExperimentalGroupRecorder(
-      SegmentInfoDatabase* storage_service,
-      FieldTrialRegister* field_trial_register,
-      const std::string& segmentation_key,
-      optimization_guide::proto::OptimizationTarget selected_segment);
+  ExperimentalGroupRecorder(SegmentInfoDatabase* storage_service,
+                            FieldTrialRegister* field_trial_register,
+                            const std::string& segmentation_key,
+                            proto::SegmentId selected_segment);
   ~ExperimentalGroupRecorder();
 
   ExperimentalGroupRecorder(ExperimentalGroupRecorder&) = delete;
diff --git a/components/segmentation_platform/internal/selection/segment_result_provider.cc b/components/segmentation_platform/internal/selection/segment_result_provider.cc
index 4772680..eafba1e 100644
--- a/components/segmentation_platform/internal/selection/segment_result_provider.cc
+++ b/components/segmentation_platform/internal/selection/segment_result_provider.cc
@@ -28,7 +28,7 @@
       segmentation_key, segment_info.prediction_result().result(),
       segment_info.model_metadata());
   VLOG(1) << __func__
-          << ": segment=" << OptimizationTarget_Name(segment_info.segment_id())
+          << ": segment=" << SegmentId_Name(segment_info.segment_id())
           << ": result=" << segment_info.prediction_result().result()
           << ", rank=" << rank;
 
@@ -110,7 +110,7 @@
 };
 
 void SegmentResultProviderImpl::GetSegmentResult(GetResultOptions&& options) {
-  const OptimizationTarget segment_id = options.segment_id;
+  const SegmentId segment_id = options.segment_id;
   auto request_state = std::make_unique<RequestState>();
   request_state = std::make_unique<RequestState>();
   request_state->options = std::move(options);
@@ -135,8 +135,8 @@
   // Don't compute results if we don't have enough signals, or don't have
   // valid unexpired results for any of the segments.
   if (!db_segment_info) {
-    VLOG(1) << __func__ << ": segment="
-            << OptimizationTarget_Name(request_state->options.segment_id)
+    VLOG(1) << __func__
+            << ": segment=" << SegmentId_Name(request_state->options.segment_id)
             << " does not have segment info.";
     TryGetScoreFromDefaultModel(std::move(request_state),
                                 ResultState::kSegmentNotAvailable,
@@ -149,8 +149,8 @@
   if (!force_refresh_results_ &&
       !signal_storage_config_->MeetsSignalCollectionRequirement(
           db_segment_info->model_metadata())) {
-    VLOG(1) << __func__ << ": segment="
-            << OptimizationTarget_Name(db_segment_info->segment_id())
+    VLOG(1) << __func__
+            << ": segment=" << SegmentId_Name(db_segment_info->segment_id())
             << " does not meet signal collection requirements.";
     TryGetScoreFromDefaultModel(std::move(request_state),
                                 ResultState::kSignalsNotCollected,
@@ -159,8 +159,8 @@
   }
 
   if (request_state->options.ignore_db_scores) {
-    VLOG(1) << __func__ << ": segment="
-            << OptimizationTarget_Name(db_segment_info->segment_id())
+    VLOG(1) << __func__
+            << ": segment=" << SegmentId_Name(db_segment_info->segment_id())
             << " executing model to get score";
     TryExecuteModelAndGetScore(std::move(request_state),
                                std::move(available_segments));
@@ -169,8 +169,8 @@
 
   if (metadata_utils::HasExpiredOrUnavailableResult(*db_segment_info,
                                                     clock_->Now())) {
-    VLOG(1) << __func__ << ": segment="
-            << OptimizationTarget_Name(db_segment_info->segment_id())
+    VLOG(1) << __func__
+            << ": segment=" << SegmentId_Name(db_segment_info->segment_id())
             << " has expired or unavailable result.";
     TryGetScoreFromDefaultModel(std::move(request_state),
                                 ResultState::kDatabaseScoreNotReady,
@@ -214,8 +214,8 @@
     DefaultModelManager::SegmentInfoList available_segments) {
   if (!request_state->default_provider ||
       !request_state->default_provider->ModelAvailable()) {
-    VLOG(1) << __func__ << ": segment="
-            << OptimizationTarget_Name(request_state->options.segment_id)
+    VLOG(1) << __func__
+            << ": segment=" << SegmentId_Name(request_state->options.segment_id)
             << " default provider not available";
     PostResultCallback(std::move(request_state),
                        std::make_unique<SegmentResult>(existing_state));
@@ -225,8 +225,8 @@
   proto::SegmentInfo* segment_info =
       GetSegmentInfo(available_segments, /*default_model=*/true);
   if (!segment_info) {
-    VLOG(1) << __func__ << ": segment="
-            << OptimizationTarget_Name(request_state->options.segment_id)
+    VLOG(1) << __func__
+            << ": segment=" << SegmentId_Name(request_state->options.segment_id)
             << " default segment info not available";
     PostResultCallback(std::move(request_state),
                        std::make_unique<SegmentResult>(
@@ -243,8 +243,8 @@
   if (!force_refresh_results_ &&
       !signal_storage_config_->MeetsSignalCollectionRequirement(
           default_segment_info->model_metadata())) {
-    VLOG(1) << __func__ << ": segment="
-            << OptimizationTarget_Name(request_state->options.segment_id)
+    VLOG(1) << __func__
+            << ": segment=" << SegmentId_Name(request_state->options.segment_id)
             << " signal collection not met";
     PostResultCallback(std::move(request_state),
                        std::make_unique<SegmentResult>(
diff --git a/components/segmentation_platform/internal/selection/segment_result_provider.h b/components/segmentation_platform/internal/selection/segment_result_provider.h
index 2458763..883d9b20b 100644
--- a/components/segmentation_platform/internal/selection/segment_result_provider.h
+++ b/components/segmentation_platform/internal/selection/segment_result_provider.h
@@ -7,9 +7,9 @@
 
 #include "base/callback.h"
 #include "base/memory/scoped_refptr.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/database/segment_info_database.h"
 #include "components/segmentation_platform/internal/input_context.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 #include "third_party/abseil-cpp/absl/types/optional.h"
 
 namespace base {
@@ -75,8 +75,7 @@
     GetResultOptions& operator=(GetResultOptions&&);
 
     // The segment ID to fetch result for.
-    OptimizationTarget segment_id =
-        OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN;
+    SegmentId segment_id = SegmentId::OPTIMIZATION_TARGET_UNKNOWN;
 
     // The key is used for recording metrics only.
     std::string segmentation_key;
diff --git a/components/segmentation_platform/internal/selection/segment_result_provider_unittest.cc b/components/segmentation_platform/internal/selection/segment_result_provider_unittest.cc
index 2b91a51..6dc2450 100644
--- a/components/segmentation_platform/internal/selection/segment_result_provider_unittest.cc
+++ b/components/segmentation_platform/internal/selection/segment_result_provider_unittest.cc
@@ -29,10 +29,10 @@
 using ::testing::ByMove;
 using ::testing::Return;
 
-const OptimizationTarget kTestSegment =
-    OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
-const OptimizationTarget kTestSegment2 =
-    OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_VOICE;
+const SegmentId kTestSegment =
+    SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
+const SegmentId kTestSegment2 =
+    SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_VOICE;
 
 constexpr float kTestScore = 0.1;
 constexpr float kDatabaseScore = 0.6;
@@ -42,14 +42,13 @@
 class TestModelProvider : public ModelProvider {
  public:
   static constexpr int64_t kVersion = 10;
-  explicit TestModelProvider(OptimizationTarget segment)
-      : ModelProvider(segment) {}
+  explicit TestModelProvider(SegmentId segment) : ModelProvider(segment) {}
 
   void InitAndFetchModel(
       const ModelUpdatedCallback& model_updated_callback) override {
     proto::SegmentationModelMetadata metadata;
     metadata.set_time_unit(proto::TimeUnit::DAY);
-    model_updated_callback.Run(optimization_target_, metadata, kVersion);
+    model_updated_callback.Run(segment_id_, metadata, kVersion);
   }
 
   void ExecuteModelWithInput(const std::vector<float>& inputs,
@@ -63,9 +62,7 @@
 
 class MockModelExecutionManager : public ModelExecutionManager {
  public:
-  MOCK_METHOD(ModelProvider*,
-              GetProvider,
-              (optimization_guide::proto::OptimizationTarget segment_id));
+  MOCK_METHOD(ModelProvider*, GetProvider, (proto::SegmentId segment_id));
 };
 
 }  // namespace
@@ -78,7 +75,7 @@
   void SetUp() override {
     default_manager_ = std::make_unique<DefaultModelManager>(
         &provider_factory_,
-        std::vector<OptimizationTarget>({kTestSegment, kTestSegment2}));
+        std::vector<SegmentId>({kTestSegment, kTestSegment2}));
     segment_database_ = std::make_unique<test::TestSegmentInfoDatabase>();
     execution_service_ = std::make_unique<ExecutionService>();
     auto query_processor =
@@ -104,7 +101,7 @@
   }
 
   void ExpectSegmentResultOnGet(
-      OptimizationTarget segment_id,
+      SegmentId segment_id,
       bool ignore_db_scores,
       SegmentResultProvider::ResultState expected_state,
       absl::optional<int> expected_rank) {
@@ -130,8 +127,7 @@
     wait_for_result.Run();
   }
 
-  void SetSegmentResult(OptimizationTarget segment,
-                        absl::optional<float> score) {
+  void SetSegmentResult(SegmentId segment, absl::optional<float> score) {
     absl::optional<proto::PredictionResult> result;
     if (score) {
       result = proto::PredictionResult();
@@ -146,7 +142,7 @@
     wait_for_save.Run();
   }
 
-  void InitializeMetadata(OptimizationTarget segment_id) {
+  void InitializeMetadata(SegmentId segment_id) {
     segment_database_->FindOrCreateSegment(segment_id)
         ->mutable_model_metadata()
         ->set_result_time_to_live(7);
@@ -274,7 +270,7 @@
 
 TEST_F(SegmentResultProviderTest, DefaultNeedsSignal) {
   SetSegmentResult(kTestSegment, absl::nullopt);
-  std::map<OptimizationTarget, std::unique_ptr<ModelProvider>> p;
+  std::map<SegmentId, std::unique_ptr<ModelProvider>> p;
   p.emplace(kTestSegment, std::make_unique<TestModelProvider>(kTestSegment));
   default_manager_->SetDefaultProvidersForTesting(std::move(p));
 
@@ -292,7 +288,7 @@
 
 TEST_F(SegmentResultProviderTest, DefaultModelFailedExecution) {
   SetSegmentResult(kTestSegment, absl::nullopt);
-  std::map<OptimizationTarget, std::unique_ptr<ModelProvider>> p;
+  std::map<SegmentId, std::unique_ptr<ModelProvider>> p;
   p.emplace(kTestSegment, std::make_unique<TestModelProvider>(kTestSegment));
   default_manager_->SetDefaultProvidersForTesting(std::move(p));
 
@@ -313,7 +309,7 @@
 
 TEST_F(SegmentResultProviderTest, GetFromDefault) {
   SetSegmentResult(kTestSegment, absl::nullopt);
-  std::map<OptimizationTarget, std::unique_ptr<ModelProvider>> p;
+  std::map<SegmentId, std::unique_ptr<ModelProvider>> p;
   p.emplace(kTestSegment, std::make_unique<TestModelProvider>(kTestSegment));
   default_manager_->SetDefaultProvidersForTesting(std::move(p));
 
@@ -334,7 +330,7 @@
   InitializeMetadata(kTestSegment2);
   SetSegmentResult(kTestSegment2, kDatabaseScore);
 
-  std::map<OptimizationTarget, std::unique_ptr<ModelProvider>> p;
+  std::map<SegmentId, std::unique_ptr<ModelProvider>> p;
   p.emplace(kTestSegment, std::make_unique<TestModelProvider>(kTestSegment));
   p.emplace(kTestSegment2, std::make_unique<TestModelProvider>(kTestSegment2));
   default_manager_->SetDefaultProvidersForTesting(std::move(p));
diff --git a/components/segmentation_platform/internal/selection/segment_score_provider.cc b/components/segmentation_platform/internal/selection/segment_score_provider.cc
index 6c05a9b..df50bac 100644
--- a/components/segmentation_platform/internal/selection/segment_score_provider.cc
+++ b/components/segmentation_platform/internal/selection/segment_score_provider.cc
@@ -30,7 +30,7 @@
                        weak_ptr_factory_.GetWeakPtr(), std::move(callback)));
   }
 
-  void GetSegmentScore(OptimizationTarget segment_id,
+  void GetSegmentScore(SegmentId segment_id,
                        SegmentScoreCallback callback) override {
     DCHECK(initialized_);
 
@@ -49,7 +49,7 @@
       std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> all_segments) {
     // Read results from last session to memory.
     for (const auto& pair : *all_segments) {
-      OptimizationTarget id = pair.first;
+      SegmentId id = pair.first;
       const proto::SegmentInfo& info = pair.second;
       if (!info.has_prediction_result())
         continue;
@@ -67,7 +67,7 @@
 
   // Model scores that are read from db on startup and used for serving the
   // clients in the current session.
-  std::map<OptimizationTarget, float> scores_last_session_;
+  std::map<SegmentId, float> scores_last_session_;
 
   // Whether the initialization is complete through an Initialize call.
   bool initialized_{false};
diff --git a/components/segmentation_platform/internal/selection/segment_score_provider.h b/components/segmentation_platform/internal/selection/segment_score_provider.h
index dc69412..792809d 100644
--- a/components/segmentation_platform/internal/selection/segment_score_provider.h
+++ b/components/segmentation_platform/internal/selection/segment_score_provider.h
@@ -6,13 +6,13 @@
 #define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SELECTION_SEGMENT_SCORE_PROVIDER_H_
 
 #include "base/callback.h"
-#include "components/optimization_guide/proto/models.pb.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 #include "third_party/abseil-cpp/absl/types/optional.h"
 
-using optimization_guide::proto::OptimizationTarget;
-
 namespace segmentation_platform {
 
+using proto::SegmentId;
+
 class SegmentInfoDatabase;
 
 // Result of a single segment.
@@ -50,7 +50,7 @@
   // from the last session.
   // Note that there is no strong reason to keep this async, feel free to change
   // this to sync if needed.
-  virtual void GetSegmentScore(OptimizationTarget segment_id,
+  virtual void GetSegmentScore(SegmentId segment_id,
                                SegmentScoreCallback callback) = 0;
 };
 
diff --git a/components/segmentation_platform/internal/selection/segment_score_provider_unittest.cc b/components/segmentation_platform/internal/selection/segment_score_provider_unittest.cc
index 8fa0073..384ecc94 100644
--- a/components/segmentation_platform/internal/selection/segment_score_provider_unittest.cc
+++ b/components/segmentation_platform/internal/selection/segment_score_provider_unittest.cc
@@ -29,7 +29,7 @@
         SegmentScoreProvider::Create(segment_database_.get());
   }
 
-  void InitializeMetadataForSegment(OptimizationTarget segment_id,
+  void InitializeMetadataForSegment(SegmentId segment_id,
                                     float mapping[][2],
                                     int num_mapping_pairs) {
     auto* metadata = segment_database_->FindOrCreateSegment(segment_id)
@@ -43,8 +43,7 @@
         segment_id, mapping, num_mapping_pairs, default_mapping_key);
   }
 
-  void GetSegmentScore(OptimizationTarget segment_id,
-                       const SegmentScore& expected) {
+  void GetSegmentScore(SegmentId segment_id, const SegmentScore& expected) {
     base::RunLoop loop;
     single_segment_manager_->GetSegmentScore(
         segment_id,
@@ -66,8 +65,7 @@
 };
 
 TEST_F(SegmentScoreProviderTest, GetSegmentScore) {
-  OptimizationTarget segment_id1 =
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
+  SegmentId segment_id1 = SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
   float mapping1[][2] = {{0.2, 1}, {0.5, 3}, {0.7, 4}};
   InitializeMetadataForSegment(segment_id1, mapping1, 3);
   segment_database_->AddPredictionResult(segment_id1, 0.6, base::Time::Now());
@@ -86,7 +84,7 @@
   GetSegmentScore(segment_id1, expected);
 
   // Returns empty results when called on a segment with no scores.
-  GetSegmentScore(OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_VOICE,
+  GetSegmentScore(SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_VOICE,
                   SegmentScore());
 }
 
diff --git a/components/segmentation_platform/internal/selection/segment_selector.h b/components/segmentation_platform/internal/selection/segment_selector.h
index e1a4d646..a601823 100644
--- a/components/segmentation_platform/internal/selection/segment_selector.h
+++ b/components/segmentation_platform/internal/selection/segment_selector.h
@@ -7,15 +7,15 @@
 
 #include "base/callback.h"
 #include "base/memory/scoped_refptr.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/execution/model_execution_status.h"
 #include "components/segmentation_platform/internal/scheduler/model_execution_scheduler.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 #include "third_party/abseil-cpp/absl/types/optional.h"
 
-using optimization_guide::proto::OptimizationTarget;
-
 namespace segmentation_platform {
 
+using proto::SegmentId;
+
 struct InputContext;
 struct SegmentSelectionResult;
 class ExecutionService;
diff --git a/components/segmentation_platform/internal/selection/segment_selector_impl.cc b/components/segmentation_platform/internal/selection/segment_selector_impl.cc
index 6a53679..fa039feb 100644
--- a/components/segmentation_platform/internal/selection/segment_selector_impl.cc
+++ b/components/segmentation_platform/internal/selection/segment_selector_impl.cc
@@ -20,6 +20,7 @@
 #include "components/segmentation_platform/internal/platform_options.h"
 #include "components/segmentation_platform/internal/proto/model_metadata.pb.h"
 #include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
+#include "components/segmentation_platform/internal/segment_id_convertor.h"
 #include "components/segmentation_platform/internal/selection/experimental_group_recorder.h"
 #include "components/segmentation_platform/internal/selection/segment_result_provider.h"
 #include "components/segmentation_platform/internal/selection/segmentation_result_prefs.h"
@@ -66,7 +67,7 @@
 
 }  // namespace
 
-using optimization_guide::proto::OptimizationTarget_Name;
+using proto::SegmentId_Name;
 
 SegmentSelectorImpl::SegmentSelectorImpl(
     SegmentInfoDatabase* segment_database,
@@ -111,14 +112,15 @@
       stats::SegmentationKeyToTrialName(config_->segmentation_key);
   std::string group_name;
   if (selected_segment.has_value()) {
-    selected_segment_last_session_.segment = selected_segment->segment_id;
+    selected_segment_last_session_.segment =
+        SegmentIdToOptimizationTarget(selected_segment->segment_id);
     selected_segment_last_session_.is_ready = true;
     stats::RecordSegmentSelectionFailure(
         config_->segmentation_key,
         stats::SegmentationSelectionFailureReason::kSelectionAvailableInPrefs);
 
     group_name = stats::OptimizationTargetToSegmentGroupName(
-        *selected_segment_last_session_.segment);
+        selected_segment->segment_id);
   } else {
     stats::RecordSegmentSelectionFailure(
         config_->segmentation_key, stats::SegmentationSelectionFailureReason::
@@ -148,7 +150,7 @@
   // TODO(ssid): Store the scores in prefs so that this can be recorded earlier
   // in startup.
   if (selected_segment_last_session_.is_ready) {
-    for (const OptimizationTarget segment_id : config_->segment_ids) {
+    for (const SegmentId segment_id : config_->segment_ids) {
       experimental_group_recorder_.emplace_back(
           std::make_unique<ExperimentalGroupRecorder>(
               segment_database_, field_trial_register_,
@@ -176,8 +178,7 @@
                         std::move(callback));
 }
 
-void SegmentSelectorImpl::OnModelExecutionCompleted(
-    OptimizationTarget segment_id) {
+void SegmentSelectorImpl::OnModelExecutionCompleted(SegmentId segment_id) {
   DCHECK(segment_result_provider_);
 
   // If the |segment_id| is not in config, then skip any updates early.
@@ -196,7 +197,7 @@
       result_prefs_->ReadSegmentationResultFromPref(config_->segmentation_key);
   if (previous_selection.has_value()) {
     bool was_unknown_selected = previous_selection->segment_id ==
-                                OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN;
+                                SegmentId::OPTIMIZATION_TARGET_UNKNOWN;
     base::TimeDelta ttl_to_use = was_unknown_selected
                                      ? config_->unknown_selection_ttl
                                      : config_->segment_selection_ttl;
@@ -206,7 +207,7 @@
           config_->segmentation_key,
           stats::SegmentationSelectionFailureReason::kSelectionTtlNotExpired);
       VLOG(1) << __func__ << ": previous selection of segment="
-              << OptimizationTarget_Name(previous_selection->segment_id)
+              << SegmentId_Name(previous_selection->segment_id)
               << " has not yet expired.";
       return false;
     }
@@ -227,7 +228,7 @@
     std::unique_ptr<SegmentRanks> ranks,
     scoped_refptr<InputContext> input_context,
     SegmentSelectionCallback callback) {
-  for (OptimizationTarget needed_segment : config_->segment_ids) {
+  for (SegmentId needed_segment : config_->segment_ids) {
     if (ranks->count(needed_segment) == 0) {
       SegmentResultProvider::GetResultOptions options;
       options.segment_id = needed_segment;
@@ -244,12 +245,12 @@
   }
 
   // Finished fetching ranks for all segments.
-  OptimizationTarget selected_segment = FindBestSegment(*ranks);
+  SegmentId selected_segment = FindBestSegment(*ranks);
   if (config_->on_demand_execution) {
     DCHECK(!callback.is_null());
     SegmentSelectionResult result;
     result.is_ready = true;
-    result.segment = selected_segment;
+    result.segment = SegmentIdToOptimizationTarget(selected_segment);
     std::move(callback).Run(result);
   } else {
     DCHECK(callback.is_null());
@@ -261,7 +262,7 @@
     std::unique_ptr<SegmentRanks> ranks,
     scoped_refptr<InputContext> input_context,
     SegmentSelectionCallback callback,
-    OptimizationTarget current_segment_id,
+    SegmentId current_segment_id,
     std::unique_ptr<SegmentResultProvider::SegmentResult> result) {
   if (!result->rank) {
     stats::RecordSegmentSelectionFailure(config_->segmentation_key,
@@ -273,15 +274,14 @@
   GetRankForNextSegment(std::move(ranks), input_context, std::move(callback));
 }
 
-OptimizationTarget SegmentSelectorImpl::FindBestSegment(
+SegmentId SegmentSelectorImpl::FindBestSegment(
     const SegmentRanks& segment_results) {
   int max_rank = 0;
-  OptimizationTarget max_rank_id =
-      OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN;
+  SegmentId max_rank_id = SegmentId::OPTIMIZATION_TARGET_UNKNOWN;
   // Loop through all the results. Convert them to discrete ranks. Select the
   // one with highest discrete rank.
   for (const auto& pair : segment_results) {
-    OptimizationTarget id = pair.first;
+    SegmentId id = pair.first;
     int rank = pair.second;
     if (rank > max_rank) {
       max_rank = rank;
@@ -294,10 +294,9 @@
   return max_rank_id;
 }
 
-void SegmentSelectorImpl::UpdateSelectedSegment(
-    OptimizationTarget new_selection) {
-  VLOG(1) << __func__ << ": Updating selected segment="
-          << OptimizationTarget_Name(new_selection);
+void SegmentSelectorImpl::UpdateSelectedSegment(SegmentId new_selection) {
+  VLOG(1) << __func__
+          << ": Updating selected segment=" << SegmentId_Name(new_selection);
   const auto& previous_selection =
       result_prefs_->ReadSegmentationResultFromPref(config_->segmentation_key);
 
@@ -310,7 +309,7 @@
     skip_updating_prefs = new_selection == previous_selection->segment_id;
     skip_updating_prefs |=
         config_->unknown_selection_ttl == base::TimeDelta() &&
-        new_selection == OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN;
+        new_selection == SegmentId::OPTIMIZATION_TARGET_UNKNOWN;
     // TODO(shaktisahu): Use segment selection inertia.
   }
 
diff --git a/components/segmentation_platform/internal/selection/segment_selector_impl.h b/components/segmentation_platform/internal/selection/segment_selector_impl.h
index 1835f5f3..f40d020 100644
--- a/components/segmentation_platform/internal/selection/segment_selector_impl.h
+++ b/components/segmentation_platform/internal/selection/segment_selector_impl.h
@@ -62,11 +62,11 @@
 
   // Helper function to update the selected segment in the prefs. Auto-extends
   // the selection if the new result is unknown.
-  virtual void UpdateSelectedSegment(OptimizationTarget new_selection);
+  virtual void UpdateSelectedSegment(SegmentId new_selection);
 
   // Called whenever a model eval completes. Runs segment selection to find the
   // best segment, and writes it to the pref.
-  void OnModelExecutionCompleted(OptimizationTarget segment_id) override;
+  void OnModelExecutionCompleted(SegmentId segment_id) override;
 
   void set_segment_result_provider_for_testing(
       std::unique_ptr<SegmentResultProvider> result_provider) {
@@ -77,7 +77,7 @@
   // For testing.
   friend class SegmentSelectorTest;
 
-  using SegmentRanks = base::flat_map<OptimizationTarget, int>;
+  using SegmentRanks = base::flat_map<SegmentId, int>;
 
   // Determines whether segment selection can be run based on whether the
   // segment selection TTL has expired, or selection is unavailable.
@@ -98,13 +98,13 @@
       std::unique_ptr<SegmentRanks> ranks,
       scoped_refptr<InputContext> input_context,
       SegmentSelectionCallback callback,
-      OptimizationTarget current_segment_id,
+      SegmentId current_segment_id,
       std::unique_ptr<SegmentResultProvider::SegmentResult> result);
 
   // Loops through all segments, performs discrete mapping, honors finch
   // supplied tie-breakers, TTL, inertia etc, and finds the highest rank.
   // Ignores the segments that have no results.
-  OptimizationTarget FindBestSegment(const SegmentRanks& segment_scores);
+  SegmentId FindBestSegment(const SegmentRanks& segment_scores);
 
   std::unique_ptr<SegmentResultProvider> segment_result_provider_;
 
diff --git a/components/segmentation_platform/internal/selection/segment_selector_unittest.cc b/components/segmentation_platform/internal/selection/segment_selector_unittest.cc
index 27aa9a6..8e598b44 100644
--- a/components/segmentation_platform/internal/selection/segment_selector_unittest.cc
+++ b/components/segmentation_platform/internal/selection/segment_selector_unittest.cc
@@ -15,6 +15,7 @@
 #include "components/segmentation_platform/internal/execution/mock_model_provider.h"
 #include "components/segmentation_platform/internal/metadata/metadata_utils.h"
 #include "components/segmentation_platform/internal/metric_filter_utils.h"
+#include "components/segmentation_platform/internal/segment_id_convertor.h"
 #include "components/segmentation_platform/internal/selection/segmentation_result_prefs.h"
 #include "components/segmentation_platform/public/config.h"
 #include "components/segmentation_platform/public/field_trial_register.h"
@@ -41,7 +42,7 @@
 
   MOCK_METHOD3(RegisterSubsegmentFieldTrialIfNeeded,
                void(base::StringPiece trial_name,
-                    optimization_guide::proto::OptimizationTarget segment_id,
+                    proto::SegmentId segment_id,
                     int subsegment_rank));
 };
 
@@ -50,9 +51,8 @@
   config.segmentation_key = "test_key";
   config.segment_selection_ttl = base::Days(28);
   config.unknown_selection_ttl = base::Days(14);
-  config.segment_ids = {
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE};
+  config.segment_ids = {SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
+                        SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE};
   return config;
 }
 
@@ -116,7 +116,7 @@
     std::move(closure).Run();
   }
 
-  void InitializeMetadataForSegment(OptimizationTarget segment_id,
+  void InitializeMetadataForSegment(SegmentId segment_id,
                                     float mapping[][2],
                                     int num_mapping_pairs) {
     EXPECT_CALL(signal_storage_config_, MeetsSignalCollectionRequirement(_, _))
@@ -130,7 +130,7 @@
         segment_id, mapping, num_mapping_pairs, config_.segmentation_key);
   }
 
-  void CompleteModelExecution(OptimizationTarget segment_id, float score) {
+  void CompleteModelExecution(SegmentId segment_id, float score) {
     segment_database_->AddPredictionResult(segment_id, score, clock_.Now());
     segment_selector_->OnModelExecutionCompleted(segment_id);
     task_environment_.RunUntilIdle();
@@ -154,13 +154,11 @@
   EXPECT_CALL(signal_storage_config_, MeetsSignalCollectionRequirement(_, _))
       .WillRepeatedly(Return(true));
 
-  OptimizationTarget segment_id =
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
+  SegmentId segment_id = SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
   float mapping[][2] = {{0.2, 1}, {0.5, 3}, {0.7, 4}};
   InitializeMetadataForSegment(segment_id, mapping, 3);
 
-  OptimizationTarget segment_id2 =
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE;
+  SegmentId segment_id2 = SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE;
   float mapping2[][2] = {{0.3, 1}, {0.4, 4}};
   InitializeMetadataForSegment(segment_id2, mapping2, 2);
 
@@ -181,13 +179,13 @@
   EXPECT_CALL(signal_storage_config_, MeetsSignalCollectionRequirement(_, _))
       .WillRepeatedly(Return(true));
 
-  static constexpr OptimizationTarget kSegmentId =
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
+  static constexpr SegmentId kSegmentId =
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
   float mapping[][2] = {{0.2, 1}, {0.5, 3}, {0.7, 4}};
   InitializeMetadataForSegment(kSegmentId, mapping, 3);
 
-  static constexpr OptimizationTarget kSegmentId2 =
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE;
+  static constexpr SegmentId kSegmentId2 =
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE;
   float mapping2[][2] = {{0.3, 1}, {0.4, 4}};
   InitializeMetadataForSegment(kSegmentId2, mapping2, 2);
 
@@ -214,7 +212,8 @@
       base::BindOnce(
           [](base::OnceClosure quit, const SegmentSelectionResult& result) {
             EXPECT_TRUE(result.is_ready);
-            EXPECT_EQ(kSegmentId2, result.segment);
+            EXPECT_EQ(kSegmentId2,
+                      OptimizationTargetToSegmentId(*result.segment));
             std::move(quit).Run();
           },
           wait_for_selection.QuitClosure()));
@@ -227,13 +226,11 @@
   SetUpWithConfig(config);
 
   // Setup test with two models.
-  OptimizationTarget segment_id1 =
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
+  SegmentId segment_id1 = SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
   float mapping1[][2] = {{0.2, 1}, {0.5, 3}, {0.7, 4}, {0.8, 5}};
   InitializeMetadataForSegment(segment_id1, mapping1, 4);
 
-  OptimizationTarget segment_id2 =
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE;
+  SegmentId segment_id2 = SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE;
   float mapping2[][2] = {{0.3, 1}, {0.4, 4}};
   InitializeMetadataForSegment(segment_id2, mapping2, 2);
 
@@ -247,13 +244,13 @@
 
   CompleteModelExecution(segment_id2, 0.1);
   ASSERT_TRUE(prefs_->selection.has_value());
-  ASSERT_EQ(OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN,
+  ASSERT_EQ(SegmentId::OPTIMIZATION_TARGET_UNKNOWN,
             prefs_->selection->segment_id);
 
   // Model 1 completes with a good score. Model 2 results are expired.
   clock_.Advance(config_.segment_selection_ttl * 1.2f);
   CompleteModelExecution(segment_id1, 0.6);
-  ASSERT_EQ(OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN,
+  ASSERT_EQ(SegmentId::OPTIMIZATION_TARGET_UNKNOWN,
             prefs_->selection->segment_id);
 
   // Model 2 gets fresh results. Now segment selection will update.
@@ -289,12 +286,11 @@
 TEST_F(SegmentSelectorTest, UnknownSegmentTtlExpiryForBooleanModel) {
   Config config = CreateTestConfig();
   config.segment_ids = {
-      OptimizationTarget::
-          OPTIMIZATION_TARGET_SEGMENTATION_CHROME_START_ANDROID};
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_CHROME_START_ANDROID};
   SetUpWithConfig(config);
 
-  OptimizationTarget segment_id =
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_CHROME_START_ANDROID;
+  SegmentId segment_id =
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_CHROME_START_ANDROID;
   float mapping[][2] = {{0.7, 1}};
   InitializeMetadataForSegment(segment_id, mapping, 1);
 
@@ -304,42 +300,39 @@
   // Set a value less than 1 and result should be UNKNOWN.
   CompleteModelExecution(segment_id, 0);
   ASSERT_TRUE(prefs_->selection.has_value());
-  ASSERT_EQ(OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN,
+  ASSERT_EQ(SegmentId::OPTIMIZATION_TARGET_UNKNOWN,
             prefs_->selection->segment_id);
 
   // Advance by less than UNKNOWN segment TTL and result should not change,
   // UNKNOWN segment TTL is less than selection TTL.
   clock_.Advance(config_.unknown_selection_ttl * 0.8f);
   CompleteModelExecution(segment_id, 0.9);
-  ASSERT_EQ(OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN,
+  ASSERT_EQ(SegmentId::OPTIMIZATION_TARGET_UNKNOWN,
             prefs_->selection->segment_id);
 
   // Advance clock so that the time is between UNKNOWN segment TTL and selection
   // TTL.
   clock_.Advance(config_.unknown_selection_ttl * 0.4f);
   CompleteModelExecution(segment_id, 0.9);
-  ASSERT_EQ(
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_CHROME_START_ANDROID,
-      prefs_->selection->segment_id);
+  ASSERT_EQ(SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_CHROME_START_ANDROID,
+            prefs_->selection->segment_id);
 
   // Advance by more than UNKNOWN segment TTL and result should not change.
   clock_.Advance(config_.unknown_selection_ttl * 1.2f);
   CompleteModelExecution(segment_id, 0);
-  ASSERT_EQ(
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_CHROME_START_ANDROID,
-      prefs_->selection->segment_id);
+  ASSERT_EQ(SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_CHROME_START_ANDROID,
+            prefs_->selection->segment_id);
 
   // Advance by segment selection TTL and result should change.
   clock_.Advance(config_.segment_selection_ttl * 1.2f);
   CompleteModelExecution(segment_id, 0);
-  ASSERT_EQ(OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN,
+  ASSERT_EQ(SegmentId::OPTIMIZATION_TARGET_UNKNOWN,
             prefs_->selection->segment_id);
 }
 
 TEST_F(SegmentSelectorTest, DoesNotMeetSignalCollectionRequirement) {
   SetUpWithConfig(CreateTestConfig());
-  OptimizationTarget segment_id1 =
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
+  SegmentId segment_id1 = SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
   float mapping1[][2] = {{0.2, 1}, {0.5, 3}, {0.7, 4}, {0.8, 5}};
 
   segment_database_->FindOrCreateSegment(segment_id1)
@@ -361,10 +354,8 @@
   SetUpWithConfig(CreateTestConfig());
   EXPECT_CALL(signal_storage_config_, MeetsSignalCollectionRequirement(_, _))
       .WillRepeatedly(Return(true));
-  OptimizationTarget segment_id0 =
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE;
-  OptimizationTarget segment_id1 =
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
+  SegmentId segment_id0 = SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE;
+  SegmentId segment_id1 = SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
   float mapping0[][2] = {{1.0, 0}};
   float mapping1[][2] = {{0.2, 1}, {0.5, 3}, {0.7, 4}};
   InitializeMetadataForSegment(segment_id0, mapping0, 1);
@@ -384,7 +375,7 @@
   segment_selector_->OnPlatformInitialized(nullptr);
 
   SegmentSelectionResult result;
-  result.segment = segment_id0;
+  result.segment = SegmentIdToOptimizationTarget(segment_id0);
   result.is_ready = true;
   GetSelectedSegment(result);
   ASSERT_EQ(result, segment_selector_->GetCachedSegmentResult());
@@ -410,13 +401,11 @@
   EXPECT_CALL(signal_storage_config_, MeetsSignalCollectionRequirement(_, _))
       .WillRepeatedly(Return(true));
 
-  OptimizationTarget segment_id =
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
+  SegmentId segment_id = SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
   float mapping[][2] = {{0.2, 1}, {0.5, 3}, {0.7, 4}};
   InitializeMetadataForSegment(segment_id, mapping, 3);
 
-  OptimizationTarget segment_id2 =
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE;
+  SegmentId segment_id2 = SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE;
   float mapping2[][2] = {{0.3, 1}, {0.4, 4}};
   InitializeMetadataForSegment(segment_id2, mapping2, 2);
 
@@ -436,8 +425,8 @@
 }
 
 TEST_F(SegmentSelectorTest, SubsegmentRecording) {
-  const OptimizationTarget kSubsegmentEnabledTarget =
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_FEED_USER;
+  const SegmentId kSubsegmentEnabledTarget =
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_FEED_USER;
 
   // Create config with Feed segment.
   Config config = CreateTestConfig();
@@ -450,8 +439,7 @@
   SetUpWithConfig(config);
 
   // Store model metadata, model scores and selection results.
-  OptimizationTarget segment_id0 =
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE;
+  SegmentId segment_id0 = SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE;
   float mapping0[][2] = {{1.0, 0}};
   InitializeMetadataForSegment(segment_id0, mapping0, 1);
   segment_database_->AddPredictionResult(segment_id0, 0.7, clock_.Now());
@@ -495,8 +483,8 @@
               RegisterSubsegmentFieldTrialIfNeeded(
                   base::StringPiece("Segmentation_TestKey_FeedUserSegment"),
                   kSubsegmentEnabledTarget, 3))
-      .WillOnce(Invoke(
-          [&wait_for_subsegment](base::StringPiece, OptimizationTarget, int) {
+      .WillOnce(
+          Invoke([&wait_for_subsegment](base::StringPiece, SegmentId, int) {
             wait_for_subsegment.QuitClosure().Run();
           }));
 
diff --git a/components/segmentation_platform/internal/selection/segmentation_result_prefs.cc b/components/segmentation_platform/internal/selection/segmentation_result_prefs.cc
index e5f87e296..91c669c 100644
--- a/components/segmentation_platform/internal/selection/segmentation_result_prefs.cc
+++ b/components/segmentation_platform/internal/selection/segmentation_result_prefs.cc
@@ -12,7 +12,7 @@
 
 namespace segmentation_platform {
 
-SelectedSegment::SelectedSegment(OptimizationTarget segment_id)
+SelectedSegment::SelectedSegment(SegmentId segment_id)
     : segment_id(segment_id), in_use(false) {}
 
 SelectedSegment::~SelectedSegment() = default;
@@ -57,8 +57,7 @@
   absl::optional<base::Time> selection_time =
       base::ValueToTime(segmentation_result.FindPath("selection_time"));
 
-  SelectedSegment selected_segment(
-      static_cast<OptimizationTarget>(segment_id.value()));
+  SelectedSegment selected_segment(static_cast<SegmentId>(segment_id.value()));
   if (in_use.has_value())
     selected_segment.in_use = in_use.value();
   if (selection_time.has_value())
diff --git a/components/segmentation_platform/internal/selection/segmentation_result_prefs.h b/components/segmentation_platform/internal/selection/segmentation_result_prefs.h
index e966dd7..a48ae879 100644
--- a/components/segmentation_platform/internal/selection/segmentation_result_prefs.h
+++ b/components/segmentation_platform/internal/selection/segmentation_result_prefs.h
@@ -7,24 +7,24 @@
 
 #include "base/memory/raw_ptr.h"
 #include "base/time/time.h"
-#include "components/optimization_guide/proto/models.pb.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 #include "third_party/abseil-cpp/absl/types/optional.h"
 
-using optimization_guide::proto::OptimizationTarget;
-
 class PrefService;
 
 namespace segmentation_platform {
 
+using proto::SegmentId;
+
 // Struct containing information about the selected segment. Convenient for
 // reading and writing to prefs.
 struct SelectedSegment {
  public:
-  explicit SelectedSegment(OptimizationTarget segment_id);
+  explicit SelectedSegment(SegmentId segment_id);
   ~SelectedSegment();
 
   // The segment selection result.
-  OptimizationTarget segment_id;
+  SegmentId segment_id;
 
   // The time when the segment was selected.
   base::Time selection_time;
diff --git a/components/segmentation_platform/internal/selection/segmentation_result_prefs_unittest.cc b/components/segmentation_platform/internal/selection/segmentation_result_prefs_unittest.cc
index c105671..386ee18 100644
--- a/components/segmentation_platform/internal/selection/segmentation_result_prefs_unittest.cc
+++ b/components/segmentation_platform/internal/selection/segmentation_result_prefs_unittest.cc
@@ -39,8 +39,7 @@
   EXPECT_FALSE(current_result.has_value());
 
   // Save a result. Verify by reading the result back.
-  OptimizationTarget segment_id =
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
+  SegmentId segment_id = SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
   SelectedSegment selected_segment(segment_id);
   result_prefs_->SaveSegmentationResultToPref(result_key, selected_segment);
   current_result = result_prefs_->ReadSegmentationResultFromPref(result_key);
@@ -51,7 +50,7 @@
 
   // Overwrite the result with a new segment.
   selected_segment.segment_id =
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE;
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE;
   selected_segment.in_use = true;
   base::Time now = base::Time::Now();
   selected_segment.selection_time = now;
@@ -66,7 +65,7 @@
   // first key.
   std::string result_key2 = "some_key2";
   selected_segment.segment_id =
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_VOICE;
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_VOICE;
   result_prefs_->SaveSegmentationResultToPref(result_key2, selected_segment);
   current_result = result_prefs_->ReadSegmentationResultFromPref(result_key2);
   EXPECT_TRUE(current_result.has_value());
@@ -74,7 +73,7 @@
 
   current_result = result_prefs_->ReadSegmentationResultFromPref(result_key);
   EXPECT_TRUE(current_result.has_value());
-  EXPECT_EQ(OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE,
+  EXPECT_EQ(SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE,
             current_result->segment_id);
 
   // Save empty result. It should delete the current result.
@@ -85,7 +84,7 @@
 
 TEST_F(SegmentationResultPrefsTest, CorruptedValue) {
   std::string result_key = "some_key";
-  SelectedSegment selected_segment(static_cast<OptimizationTarget>(100));
+  SelectedSegment selected_segment(static_cast<SegmentId>(100));
   result_prefs_->SaveSegmentationResultToPref(result_key, selected_segment);
   absl::optional<SelectedSegment> current_result =
       result_prefs_->ReadSegmentationResultFromPref(result_key);
diff --git a/components/segmentation_platform/internal/service_proxy_impl.cc b/components/segmentation_platform/internal/service_proxy_impl.cc
index 93e6d829..2bcb0da 100644
--- a/components/segmentation_platform/internal/service_proxy_impl.cc
+++ b/components/segmentation_platform/internal/service_proxy_impl.cc
@@ -14,6 +14,7 @@
 #include "components/segmentation_platform/internal/database/signal_storage_config.h"
 #include "components/segmentation_platform/internal/metadata/metadata_utils.h"
 #include "components/segmentation_platform/internal/scheduler/execution_service.h"
+#include "components/segmentation_platform/internal/segment_id_convertor.h"
 #include "components/segmentation_platform/internal/segmentation_platform_service_impl.h"
 #include "components/segmentation_platform/internal/selection/segment_selector_impl.h"
 #include "components/segmentation_platform/public/config.h"
@@ -107,9 +108,9 @@
   UpdateObservers(true /* update_service_status */);
 }
 
-void ServiceProxyImpl::ExecuteModel(OptimizationTarget segment_id) {
+void ServiceProxyImpl::ExecuteModel(SegmentId segment_id) {
   if (!execution_service ||
-      segment_id == OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN) {
+      segment_id == SegmentId::OPTIMIZATION_TARGET_UNKNOWN) {
     return;
   }
   segment_db_->GetSegmentInfo(
@@ -129,27 +130,26 @@
   execution_service->RequestModelExecution(std::move(request));
 }
 
-void ServiceProxyImpl::OverwriteResult(OptimizationTarget segment_id,
-                                       float result) {
+void ServiceProxyImpl::OverwriteResult(SegmentId segment_id, float result) {
   if (!execution_service)
     return;
 
   if (result < 0 || result > 1)
     return;
 
-  if (segment_id != OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN) {
+  if (segment_id != SegmentId::OPTIMIZATION_TARGET_UNKNOWN) {
     execution_service->OverwriteModelExecutionResult(
         segment_id, std::make_pair(result, ModelExecutionStatus::kSuccess));
   }
 }
 
 void ServiceProxyImpl::SetSelectedSegment(const std::string& segmentation_key,
-                                          OptimizationTarget segment_id) {
+                                          SegmentId segment_id) {
   if (!segment_selectors_ ||
       segment_selectors_->find(segmentation_key) == segment_selectors_->end()) {
     return;
   }
-  if (segment_id != OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN) {
+  if (segment_id != SegmentId::OPTIMIZATION_TARGET_UNKNOWN) {
     auto& selector = segment_selectors_->at(segmentation_key);
     selector->UpdateSelectedSegment(segment_id);
   }
@@ -161,31 +161,30 @@
     return;
 
   // Convert the |segment_info| vector to a map for quick lookup.
-  base::flat_map<OptimizationTarget, proto::SegmentInfo> optimization_targets;
+  base::flat_map<SegmentId, proto::SegmentInfo> segment_ids;
   for (const auto& info : *segment_info) {
-    optimization_targets[info.first] = info.second;
+    segment_ids[info.first] = info.second;
   }
 
   std::vector<ServiceProxy::ClientInfo> result;
   for (const auto& config : *configs_) {
-    OptimizationTarget selected =
-        OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN;
+    SegmentId selected = SegmentId::OPTIMIZATION_TARGET_UNKNOWN;
     if (segment_selectors_ &&
         segment_selectors_->find(config->segmentation_key) !=
             segment_selectors_->end()) {
-      absl::optional<optimization_guide::proto::OptimizationTarget> target =
+      absl::optional<OptimizationTarget> target =
           segment_selectors_->at(config->segmentation_key)
               ->GetCachedSegmentResult()
               .segment;
       if (target.has_value()) {
-        selected = target.value();
+        selected = OptimizationTargetToSegmentId(*target);
       }
     }
     result.emplace_back(config->segmentation_key, selected);
     for (const auto& segment_id : config->segment_ids) {
-      if (!optimization_targets.contains(segment_id))
+      if (!segment_ids.contains(segment_id))
         continue;
-      const auto& info = optimization_targets[segment_id];
+      const auto& info = segment_ids[segment_id];
       result.back().segment_status.emplace_back(
           segment_id, SegmentMetadataToString(info),
           PredictionResultToString(info),
@@ -200,8 +199,7 @@
     obs.OnClientInfoAvailable(result);
 }
 
-void ServiceProxyImpl::OnModelExecutionCompleted(
-    OptimizationTarget segment_id) {
+void ServiceProxyImpl::OnModelExecutionCompleted(SegmentId segment_id) {
   // Update the observers with the new execution results.
   UpdateObservers(false);
 }
diff --git a/components/segmentation_platform/internal/service_proxy_impl.h b/components/segmentation_platform/internal/service_proxy_impl.h
index 96d521f..1d4ea7fc 100644
--- a/components/segmentation_platform/internal/service_proxy_impl.h
+++ b/components/segmentation_platform/internal/service_proxy_impl.h
@@ -12,14 +12,14 @@
 #include "base/memory/raw_ptr.h"
 #include "base/observer_list.h"
 #include "components/leveldb_proto/public/proto_database.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/database/segment_info_database.h"
 #include "components/segmentation_platform/internal/scheduler/model_execution_scheduler.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 #include "components/segmentation_platform/public/service_proxy.h"
 
-using optimization_guide::proto::OptimizationTarget;
-
 namespace segmentation_platform {
+using proto::SegmentId;
+
 struct Config;
 class SignalStorageConfig;
 class ExecutionService;
@@ -48,10 +48,10 @@
 
   // ServiceProxy impl.
   void GetServiceStatus() override;
-  void ExecuteModel(OptimizationTarget segment_id) override;
-  void OverwriteResult(OptimizationTarget segment_id, float result) override;
+  void ExecuteModel(SegmentId segment_id) override;
+  void OverwriteResult(SegmentId segment_id, float result) override;
   void SetSelectedSegment(const std::string& segmentation_key,
-                          OptimizationTarget segment_id) override;
+                          SegmentId segment_id) override;
 
   // Called when segmentation service status changed.
   void OnServiceStatusChanged(bool is_initialized, int status_flag);
@@ -70,7 +70,7 @@
       std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> segment_info);
 
   // ModelExecutionScheduler::Observer overrides.
-  void OnModelExecutionCompleted(OptimizationTarget segment_id) override;
+  void OnModelExecutionCompleted(SegmentId segment_id) override;
 
   bool is_service_initialized_ = false;
   int service_status_flag_ = 0;
diff --git a/components/segmentation_platform/internal/service_proxy_impl_unittest.cc b/components/segmentation_platform/internal/service_proxy_impl_unittest.cc
index 9e2aeb6e..a2acb02 100644
--- a/components/segmentation_platform/internal/service_proxy_impl_unittest.cc
+++ b/components/segmentation_platform/internal/service_proxy_impl_unittest.cc
@@ -38,7 +38,7 @@
 proto::SegmentInfo AddSegmentInfo(
     std::map<std::string, proto::SegmentInfo>* db_entries,
     Config* config,
-    OptimizationTarget segment_id) {
+    SegmentId segment_id) {
   proto::SegmentInfo info;
   info.set_segment_id(segment_id);
   db_entries->insert(
@@ -55,8 +55,7 @@
   MOCK_METHOD(void, RequestModelExecutionForEligibleSegments, (bool));
   MOCK_METHOD(void,
               OnModelExecutionCompleted,
-              (OptimizationTarget,
-               (const std::pair<float, ModelExecutionStatus>&)));
+              (SegmentId, (const std::pair<float, ModelExecutionStatus>&)));
 };
 
 }  // namespace
@@ -75,14 +74,14 @@
                             nullptr) {}
   ~FakeSegmentSelectorImpl() override = default;
 
-  void UpdateSelectedSegment(OptimizationTarget new_selection) override {
+  void UpdateSelectedSegment(SegmentId new_selection) override {
     new_selection_ = new_selection;
   }
 
-  OptimizationTarget new_selection() const { return new_selection_; }
+  SegmentId new_selection() const { return new_selection_; }
 
  private:
-  OptimizationTarget new_selection_;
+  SegmentId new_selection_;
 };
 
 class ServiceProxyImplTest : public testing::Test,
@@ -169,9 +168,9 @@
 }
 
 TEST_F(ServiceProxyImplTest, GetSegmentationInfoFromDB) {
-  proto::SegmentInfo info = AddSegmentInfo(
-      &db_entries_, configs_.at(0).get(),
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB);
+  proto::SegmentInfo info =
+      AddSegmentInfo(&db_entries_, configs_.at(0).get(),
+                     SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB);
   SetUpProxy();
 
   service_proxy_impl_->OnServiceStatusChanged(true, 7);
@@ -181,23 +180,23 @@
   ASSERT_EQ(client_info_.at(0).segment_status.size(), 1u);
   ServiceProxy::SegmentStatus status = client_info_.at(0).segment_status.at(0);
   ASSERT_EQ(status.segment_id,
-            OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB);
+            SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB);
   ASSERT_EQ(status.can_execute_segment, false);
   ASSERT_TRUE(status.segment_metadata.empty());
   ASSERT_TRUE(status.prediction_result.empty());
 }
 
 TEST_F(ServiceProxyImplTest, ExecuteModel) {
-  proto::SegmentInfo info = AddSegmentInfo(
-      &db_entries_, configs_.at(0).get(),
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB);
+  proto::SegmentInfo info =
+      AddSegmentInfo(&db_entries_, configs_.at(0).get(),
+                     SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB);
   SetUpProxy();
 
   service_proxy_impl_->OnServiceStatusChanged(true, 7);
   db_->LoadCallback(true);
 
   segment_db_->UpdateSegment(
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, info,
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, info,
       base::DoNothing());
   db_->UpdateCallback(true);
 
@@ -210,7 +209,7 @@
   // Scheduler is not set, ExecuteModel() will do nothing.
   EXPECT_CALL(*scheduler, RequestModelExecution(_)).Times(0);
   service_proxy_impl_->ExecuteModel(
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB);
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB);
 
   service_proxy_impl_->SetExecutionService(&execution);
   base::RunLoop wait_for_execution;
@@ -221,19 +220,18 @@
             wait_for_execution.QuitClosure().Run();
           }));
   service_proxy_impl_->ExecuteModel(
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB);
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB);
   db_->GetCallback(true);
   wait_for_execution.Run();
 
   EXPECT_CALL(*scheduler, RequestModelExecution(_)).Times(0);
-  service_proxy_impl_->ExecuteModel(
-      OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN);
+  service_proxy_impl_->ExecuteModel(SegmentId::OPTIMIZATION_TARGET_UNKNOWN);
 }
 
 TEST_F(ServiceProxyImplTest, OverwriteResult) {
-  proto::SegmentInfo info = AddSegmentInfo(
-      &db_entries_, configs_.at(0).get(),
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB);
+  proto::SegmentInfo info =
+      AddSegmentInfo(&db_entries_, configs_.at(0).get(),
+                     SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB);
   SetUpProxy();
 
   service_proxy_impl_->OnServiceStatusChanged(true, 7);
@@ -248,33 +246,32 @@
   // Scheduler is not set, OverwriteValue() will do nothing.
   EXPECT_CALL(*scheduler, OnModelExecutionCompleted(_, _)).Times(0);
   service_proxy_impl_->OverwriteResult(
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 0.7);
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 0.7);
 
   // Test with invalid values.
   service_proxy_impl_->SetExecutionService(&execution);
   EXPECT_CALL(*scheduler, OnModelExecutionCompleted(_, _)).Times(0);
   service_proxy_impl_->OverwriteResult(
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 1.1);
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 1.1);
   service_proxy_impl_->OverwriteResult(
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, -0.1);
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, -0.1);
 
-  EXPECT_CALL(
-      *scheduler,
-      OnModelExecutionCompleted(
-          OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, _))
+  EXPECT_CALL(*scheduler,
+              OnModelExecutionCompleted(
+                  SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, _))
       .Times(1);
   service_proxy_impl_->OverwriteResult(
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 0.7);
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 0.7);
 
   EXPECT_CALL(*scheduler, OnModelExecutionCompleted(_, _)).Times(0);
-  service_proxy_impl_->OverwriteResult(
-      OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN, 0.7);
+  service_proxy_impl_->OverwriteResult(SegmentId::OPTIMIZATION_TARGET_UNKNOWN,
+                                       0.7);
 }
 
 TEST_F(ServiceProxyImplTest, SetSelectSegment) {
-  proto::SegmentInfo info = AddSegmentInfo(
-      &db_entries_, configs_.at(0).get(),
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB);
+  proto::SegmentInfo info =
+      AddSegmentInfo(&db_entries_, configs_.at(0).get(),
+                     SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB);
   SetUpProxy();
 
   service_proxy_impl_->OnServiceStatusChanged(true, 7);
@@ -282,8 +279,8 @@
 
   service_proxy_impl_->SetSelectedSegment(
       kTestSegmentationKey,
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB);
-  ASSERT_EQ(OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB);
+  ASSERT_EQ(SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
             static_cast<FakeSegmentSelectorImpl*>(
                 segment_selectors_[kTestSegmentationKey].get())
                 ->new_selection());
diff --git a/components/segmentation_platform/internal/signals/history_service_observer.cc b/components/segmentation_platform/internal/signals/history_service_observer.cc
index 6996e511..4a61939 100644
--- a/components/segmentation_platform/internal/signals/history_service_observer.cc
+++ b/components/segmentation_platform/internal/signals/history_service_observer.cc
@@ -70,8 +70,7 @@
 }
 
 void HistoryServiceObserver::SetHistoryBasedSegments(
-    base::flat_set<optimization_guide::proto::OptimizationTarget>&&
-        history_based_segments) {
+    base::flat_set<proto::SegmentId>&& history_based_segments) {
   history_based_segments_ = std::move(history_based_segments);
   // If a delete is pending, clear the results now.
   if (pending_deletion_based_on_history_based_segments_) {
diff --git a/components/segmentation_platform/internal/signals/history_service_observer.h b/components/segmentation_platform/internal/signals/history_service_observer.h
index 5211e7d..75adbcc 100644
--- a/components/segmentation_platform/internal/signals/history_service_observer.h
+++ b/components/segmentation_platform/internal/signals/history_service_observer.h
@@ -12,7 +12,7 @@
 #include "base/time/time.h"
 #include "components/history/core/browser/history_service.h"
 #include "components/history/core/browser/history_service_observer.h"
-#include "components/optimization_guide/proto/models.pb.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 #include "third_party/abseil-cpp/absl/types/optional.h"
 
 namespace segmentation_platform {
@@ -44,8 +44,7 @@
 
   // Sets the list of segment IDs that are based on history data.
   virtual void SetHistoryBasedSegments(
-      base::flat_set<optimization_guide::proto::OptimizationTarget>&&
-          history_based_segments);
+      base::flat_set<proto::SegmentId>&& history_based_segments);
 
  private:
   void DeleteResultsForHistoryBasedSegments();
@@ -55,8 +54,7 @@
 
   // List of segment IDs that depend on history data, that will be cleared when
   // history is deleted.
-  absl::optional<base::flat_set<optimization_guide::proto::OptimizationTarget>>
-      history_based_segments_;
+  absl::optional<base::flat_set<proto::SegmentId>> history_based_segments_;
   bool pending_deletion_based_on_history_based_segments_ = false;
 
   base::RepeatingClosure models_refresh_callback_;
diff --git a/components/segmentation_platform/internal/signals/signal_filter_processor.cc b/components/segmentation_platform/internal/signals/signal_filter_processor.cc
index 6ddb2c36..32262b0 100644
--- a/components/segmentation_platform/internal/signals/signal_filter_processor.cc
+++ b/components/segmentation_platform/internal/signals/signal_filter_processor.cc
@@ -42,7 +42,7 @@
   std::set<uint64_t> user_actions;
   std::set<std::pair<std::string, proto::SignalType>> histograms;
   UkmConfig ukm_config;
-  base::flat_set<OptimizationTarget> history_based_segments;
+  base::flat_set<SegmentId> history_based_segments;
 
  private:
   void AddUmaFeatures(const proto::SegmentationModelMetadata& metadata) {
@@ -95,7 +95,7 @@
     UserActionSignalHandler* user_action_signal_handler,
     HistogramSignalHandler* histogram_signal_handler,
     HistoryServiceObserver* history_observer,
-    const std::vector<OptimizationTarget>& segment_ids)
+    const std::vector<SegmentId>& segment_ids)
     : storage_service_(storage_service),
       user_action_signal_handler_(user_action_signal_handler),
       histogram_signal_handler_(histogram_signal_handler),
diff --git a/components/segmentation_platform/internal/signals/signal_filter_processor.h b/components/segmentation_platform/internal/signals/signal_filter_processor.h
index 26b4232..ff39f4c 100644
--- a/components/segmentation_platform/internal/signals/signal_filter_processor.h
+++ b/components/segmentation_platform/internal/signals/signal_filter_processor.h
@@ -7,13 +7,13 @@
 
 #include "base/memory/raw_ptr.h"
 #include "base/memory/weak_ptr.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/execution/default_model_manager.h"
-
-using optimization_guide::proto::OptimizationTarget;
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 
 namespace segmentation_platform {
 
+using proto::SegmentId;
+
 class HistogramSignalHandler;
 class HistoryServiceObserver;
 class StorageService;
@@ -28,7 +28,7 @@
                         UserActionSignalHandler* user_action_signal_handler,
                         HistogramSignalHandler* histogram_signal_handler,
                         HistoryServiceObserver* history_observer,
-                        const std::vector<OptimizationTarget>& segment_ids);
+                        const std::vector<SegmentId>& segment_ids);
   ~SignalFilterProcessor();
 
   // Disallow copy/assign.
@@ -53,7 +53,7 @@
   const raw_ptr<UserActionSignalHandler> user_action_signal_handler_;
   const raw_ptr<HistogramSignalHandler> histogram_signal_handler_;
   const raw_ptr<HistoryServiceObserver> history_observer_;
-  std::vector<OptimizationTarget> segment_ids_;
+  std::vector<SegmentId> segment_ids_;
 
   base::WeakPtrFactory<SignalFilterProcessor> weak_ptr_factory_{this};
 };
diff --git a/components/segmentation_platform/internal/signals/signal_filter_processor_unittest.cc b/components/segmentation_platform/internal/signals/signal_filter_processor_unittest.cc
index d895da33..126c51c 100644
--- a/components/segmentation_platform/internal/signals/signal_filter_processor_unittest.cc
+++ b/components/segmentation_platform/internal/signals/signal_filter_processor_unittest.cc
@@ -53,21 +53,19 @@
 
 class MockHistoryObserver : public HistoryServiceObserver {
  public:
-  MOCK_METHOD1(
-      SetHistoryBasedSegments,
-      void(base::flat_set<optimization_guide::proto::OptimizationTarget>&&
-               history_based_segments));
+  MOCK_METHOD1(SetHistoryBasedSegments,
+               void(base::flat_set<proto::SegmentId>&& history_based_segments));
 };
 
 // Noop version. For database calls, just passes the calls to the DB.
 class TestDefaultModelManager : public DefaultModelManager {
  public:
   TestDefaultModelManager()
-      : DefaultModelManager(nullptr, std::vector<OptimizationTarget>()) {}
+      : DefaultModelManager(nullptr, std::vector<SegmentId>()) {}
   ~TestDefaultModelManager() override = default;
 
   void GetAllSegmentInfoFromDefaultModel(
-      const std::vector<OptimizationTarget>& segment_ids,
+      const std::vector<SegmentId>& segment_ids,
       MultipleSegmentInfoCallback callback) override {
     base::ThreadTaskRunnerHandle::Get()->PostTask(
         FROM_HERE, base::BindOnce(std::move(callback),
@@ -75,7 +73,7 @@
   }
 
   void GetAllSegmentInfoFromBothModels(
-      const std::vector<OptimizationTarget>& segment_ids,
+      const std::vector<SegmentId>& segment_ids,
       SegmentInfoDatabase* segment_database,
       MultipleSegmentInfoCallback callback) override {
     segment_database->GetSegmentInfoForSegments(
@@ -106,9 +104,9 @@
     base::SetRecordActionTaskRunner(
         task_environment_.GetMainThreadTaskRunner());
 
-    std::vector<OptimizationTarget> segment_ids(
-        {OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
-         OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE});
+    std::vector<SegmentId> segment_ids(
+        {SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
+         SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE});
     user_action_signal_handler_ =
         std::make_unique<MockUserActionSignalHandler>();
     histogram_signal_handler_ = std::make_unique<MockHistogramSignalHandler>();
@@ -144,12 +142,12 @@
 TEST_F(SignalFilterProcessorTest, UserActionRegistrationFlow) {
   std::string kUserActionName1 = "some_action_1";
   segment_database_->AddUserActionFeature(
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
-      kUserActionName1, 0, 0, proto::Aggregation::COUNT);
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, kUserActionName1, 0,
+      0, proto::Aggregation::COUNT);
   std::string kUserActionName2 = "some_action_2";
   segment_database_->AddUserActionFeature(
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE,
-      kUserActionName2, 0, 0, proto::Aggregation::COUNT);
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE, kUserActionName2, 0, 0,
+      proto::Aggregation::COUNT);
 
   std::set<uint64_t> actions;
   EXPECT_CALL(*user_action_signal_handler_, SetRelevantUserActions(_))
@@ -165,16 +163,16 @@
 TEST_F(SignalFilterProcessorTest, HistogramRegistrationFlow) {
   std::string kHistogramName1 = "some_histogram_1";
   segment_database_->AddHistogramValueFeature(
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
-      kHistogramName1, 1, 1, proto::Aggregation::COUNT);
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, kHistogramName1, 1,
+      1, proto::Aggregation::COUNT);
   std::string kHistogramName2 = "some_histogram_2";
   segment_database_->AddHistogramValueFeature(
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE,
-      kHistogramName2, 1, 1, proto::Aggregation::COUNT);
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE, kHistogramName2, 1, 1,
+      proto::Aggregation::COUNT);
   std::string kHistogramName3 = "some_histogram_3";
   segment_database_->AddHistogramEnumFeature(
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE,
-      kHistogramName3, 1, 1, proto::Aggregation::COUNT, {3, 4});
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE, kHistogramName3, 1, 1,
+      proto::Aggregation::COUNT, {3, 4});
 
   std::set<std::pair<std::string, proto::SignalType>> histograms;
   EXPECT_CALL(*histogram_signal_handler_, SetRelevantHistograms(_))
@@ -195,8 +193,8 @@
 }
 
 TEST_F(SignalFilterProcessorTest, UkmMetricsConfig) {
-  const OptimizationTarget kSegmentId =
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
+  const SegmentId kSegmentId =
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
   const UkmEventHash kEvent1 = TestEvent(10);
   const UkmEventHash kEvent2 = TestEvent(11);
   const std::array<UkmMetricHash, 3> kMetrics1_1{
@@ -260,8 +258,7 @@
         EXPECT_EQ(actual_config, config2);
       }));
   EXPECT_CALL(*history_observer_,
-              SetHistoryBasedSegments(
-                  base::flat_set<OptimizationTarget>({kSegmentId})));
+              SetHistoryBasedSegments(base::flat_set<SegmentId>({kSegmentId})));
   EXPECT_CALL(*signal_storage_config_, OnSignalCollectionStarted(_));
   signal_filter_processor_->OnSignalListUpdated();
 }
diff --git a/components/segmentation_platform/internal/signals/signal_handler.cc b/components/segmentation_platform/internal/signals/signal_handler.cc
index 4a99703..42e9a80 100644
--- a/components/segmentation_platform/internal/signals/signal_handler.cc
+++ b/components/segmentation_platform/internal/signals/signal_handler.cc
@@ -16,12 +16,10 @@
 SignalHandler::SignalHandler() = default;
 SignalHandler::~SignalHandler() = default;
 
-void SignalHandler::Initialize(
-    StorageService* storage_service,
-    history::HistoryService* history_service,
-    const std::vector<optimization_guide::proto::OptimizationTarget>&
-        segment_ids,
-    base::RepeatingClosure models_refresh_callback) {
+void SignalHandler::Initialize(StorageService* storage_service,
+                               history::HistoryService* history_service,
+                               const std::vector<proto::SegmentId>& segment_ids,
+                               base::RepeatingClosure models_refresh_callback) {
   user_action_signal_handler_ = std::make_unique<UserActionSignalHandler>(
       storage_service->signal_database());
   histogram_signal_handler_ = std::make_unique<HistogramSignalHandler>(
diff --git a/components/segmentation_platform/internal/signals/signal_handler.h b/components/segmentation_platform/internal/signals/signal_handler.h
index 77072aa..ac05648 100644
--- a/components/segmentation_platform/internal/signals/signal_handler.h
+++ b/components/segmentation_platform/internal/signals/signal_handler.h
@@ -8,7 +8,7 @@
 #include <memory>
 
 #include "base/callback.h"
-#include "components/optimization_guide/proto/models.pb.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 
 namespace history {
 class HistoryService;
@@ -33,12 +33,10 @@
   SignalHandler(SignalHandler&) = delete;
   SignalHandler& operator=(SignalHandler&) = delete;
 
-  void Initialize(
-      StorageService* storage_service,
-      history::HistoryService* history_service,
-      const std::vector<optimization_guide::proto::OptimizationTarget>&
-          segment_ids,
-      base::RepeatingClosure model_refresh_callback);
+  void Initialize(StorageService* storage_service,
+                  history::HistoryService* history_service,
+                  const std::vector<proto::SegmentId>& segment_ids,
+                  base::RepeatingClosure model_refresh_callback);
 
   void TearDown();
 
diff --git a/components/segmentation_platform/internal/stats.cc b/components/segmentation_platform/internal/stats.cc
index d8295d8..07a91488 100644
--- a/components/segmentation_platform/internal/stats.cc
+++ b/components/segmentation_platform/internal/stats.cc
@@ -9,9 +9,9 @@
 #include "base/notreached.h"
 #include "base/strings/strcat.h"
 #include "base/strings/string_util.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/proto/types.pb.h"
 #include "components/segmentation_platform/public/config.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 
 namespace segmentation_platform::stats {
 namespace {
@@ -27,7 +27,7 @@
 };
 
 // This is the segmentation subset of
-// optimization_guide::proto::OptimizationTarget.
+// proto::SegmentId.
 // Keep in sync with SegmentationPlatformSegmenationModel in
 // //tools/metrics/histograms/enums.xml.
 // See also SegmentationModel variant in
@@ -46,15 +46,15 @@
 };
 
 AdaptiveToolbarButtonVariant OptimizationTargetToAdaptiveToolbarButtonVariant(
-    OptimizationTarget segment_id) {
+    SegmentId segment_id) {
   switch (segment_id) {
-    case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB:
       return AdaptiveToolbarButtonVariant::kNewTab;
-    case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE:
       return AdaptiveToolbarButtonVariant::kShare;
-    case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_VOICE:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_VOICE:
       return AdaptiveToolbarButtonVariant::kVoice;
-    case OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN:
+    case SegmentId::OPTIMIZATION_TARGET_UNKNOWN:
       return AdaptiveToolbarButtonVariant::kNone;
     default:
       return AdaptiveToolbarButtonVariant::kUnknown;
@@ -70,70 +70,68 @@
          segmentation_key == kFeedUserSegmentationKey;
 }
 
-BooleanSegmentSwitch GetBooleanSegmentSwitch(
-    OptimizationTarget new_selection,
-    OptimizationTarget previous_selection) {
-  if (new_selection != OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN &&
-      previous_selection == OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN) {
+BooleanSegmentSwitch GetBooleanSegmentSwitch(SegmentId new_selection,
+                                             SegmentId previous_selection) {
+  if (new_selection != SegmentId::OPTIMIZATION_TARGET_UNKNOWN &&
+      previous_selection == SegmentId::OPTIMIZATION_TARGET_UNKNOWN) {
     return BooleanSegmentSwitch::kNoneToEnabled;
-  } else if (new_selection == OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN &&
-             previous_selection !=
-                 OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN) {
+  } else if (new_selection == SegmentId::OPTIMIZATION_TARGET_UNKNOWN &&
+             previous_selection != SegmentId::OPTIMIZATION_TARGET_UNKNOWN) {
     return BooleanSegmentSwitch::kEnabledToNone;
   }
   return BooleanSegmentSwitch::kUnknown;
 }
 
 AdaptiveToolbarSegmentSwitch GetAdaptiveToolbarSegmentSwitch(
-    OptimizationTarget new_selection,
-    OptimizationTarget previous_selection) {
+    SegmentId new_selection,
+    SegmentId previous_selection) {
   switch (previous_selection) {
-    case OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN:
+    case SegmentId::OPTIMIZATION_TARGET_UNKNOWN:
       switch (new_selection) {
-        case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB:
+        case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB:
           return AdaptiveToolbarSegmentSwitch::kNoneToNewTab;
-        case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE:
+        case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE:
           return AdaptiveToolbarSegmentSwitch::kNoneToShare;
-        case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_VOICE:
+        case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_VOICE:
           return AdaptiveToolbarSegmentSwitch::kNoneToVoice;
         default:
           NOTREACHED();
           return AdaptiveToolbarSegmentSwitch::kUnknown;
       }
 
-    case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB:
       switch (new_selection) {
-        case OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN:
+        case SegmentId::OPTIMIZATION_TARGET_UNKNOWN:
           return AdaptiveToolbarSegmentSwitch::kNewTabToNone;
-        case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE:
+        case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE:
           return AdaptiveToolbarSegmentSwitch::kNewTabToShare;
-        case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_VOICE:
+        case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_VOICE:
           return AdaptiveToolbarSegmentSwitch::kNewTabToVoice;
         default:
           NOTREACHED();
           return AdaptiveToolbarSegmentSwitch::kUnknown;
       }
 
-    case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE:
       switch (new_selection) {
-        case OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN:
+        case SegmentId::OPTIMIZATION_TARGET_UNKNOWN:
           return AdaptiveToolbarSegmentSwitch::kShareToNone;
-        case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB:
+        case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB:
           return AdaptiveToolbarSegmentSwitch::kShareToNewTab;
-        case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_VOICE:
+        case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_VOICE:
           return AdaptiveToolbarSegmentSwitch::kShareToVoice;
         default:
           NOTREACHED();
           return AdaptiveToolbarSegmentSwitch::kUnknown;
       }
 
-    case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_VOICE:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_VOICE:
       switch (new_selection) {
-        case OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN:
+        case SegmentId::OPTIMIZATION_TARGET_UNKNOWN:
           return AdaptiveToolbarSegmentSwitch::kVoiceToNone;
-        case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB:
+        case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB:
           return AdaptiveToolbarSegmentSwitch::kVoiceToNewTab;
-        case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE:
+        case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE:
           return AdaptiveToolbarSegmentSwitch::kVoiceToShare;
         default:
           NOTREACHED();
@@ -146,26 +144,23 @@
   }
 }
 
-SegmentationModel OptimizationTargetToSegmentationModel(
-    OptimizationTarget segment_id) {
+SegmentationModel OptimizationTargetToSegmentationModel(SegmentId segment_id) {
   switch (segment_id) {
-    case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB:
       return SegmentationModel::kNewTab;
-    case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE:
       return SegmentationModel::kShare;
-    case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_VOICE:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_VOICE:
       return SegmentationModel::kVoice;
-    case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_DUMMY:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_DUMMY:
       return SegmentationModel::kDummy;
-    case OptimizationTarget::
-        OPTIMIZATION_TARGET_SEGMENTATION_CHROME_START_ANDROID:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_CHROME_START_ANDROID:
       return SegmentationModel::kChromeStartAndroid;
-    case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_QUERY_TILES:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_QUERY_TILES:
       return SegmentationModel::kQueryTiles;
-    case OptimizationTarget::
-        OPTIMIZATION_TARGET_SEGMENTATION_CHROME_LOW_USER_ENGAGEMENT:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_CHROME_LOW_USER_ENGAGEMENT:
       return SegmentationModel::kChromeLowUserEngagement;
-    case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_FEED_USER:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_FEED_USER:
       return SegmentationModel::kFeedUserSegment;
     default:
       return SegmentationModel::kUnknown;
@@ -226,26 +221,23 @@
 
 // Should map to SegmentationModel variant in
 // //tools/metrics/histograms/metadata/segmentation_platform/histograms.xml.
-std::string OptimizationTargetToHistogramVariant(
-    OptimizationTarget segment_id) {
+std::string OptimizationTargetToHistogramVariant(SegmentId segment_id) {
   switch (segment_id) {
-    case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB:
       return "NewTab";
-    case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE:
       return "Share";
-    case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_VOICE:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_VOICE:
       return "Voice";
-    case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_DUMMY:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_DUMMY:
       return "Dummy";
-    case OptimizationTarget::
-        OPTIMIZATION_TARGET_SEGMENTATION_CHROME_START_ANDROID:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_CHROME_START_ANDROID:
       return "ChromeStartAndroid";
-    case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_QUERY_TILES:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_QUERY_TILES:
       return "QueryTiles";
-    case OptimizationTarget::
-        OPTIMIZATION_TARGET_SEGMENTATION_CHROME_LOW_USER_ENGAGEMENT:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_CHROME_LOW_USER_ENGAGEMENT:
       return "ChromeLowUserEngagement";
-    case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_FEED_USER:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_FEED_USER:
       return "FeedUserSegment";
     default:
       return "Other";
@@ -274,13 +266,13 @@
   return "Unknown";
 }
 
-void RecordModelScore(OptimizationTarget segment_id, float score) {
+void RecordModelScore(SegmentId segment_id, float score) {
   // Special case adaptive toolbar models since it already has histograms being
   // recorded and updating names will affect current work.
   switch (segment_id) {
-    case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_VOICE:
-    case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB:
-    case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_VOICE:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE:
       base::UmaHistogramPercentage(
           "SegmentationPlatform.AdaptiveToolbar.ModelScore." +
               OptimizationTargetToHistogramVariant(segment_id),
@@ -291,16 +283,14 @@
   }
 
   switch (segment_id) {
-    case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_VOICE:
-    case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB:
-    case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE:
-    case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_DUMMY:
-    case OptimizationTarget::
-        OPTIMIZATION_TARGET_SEGMENTATION_CHROME_START_ANDROID:
-    case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_QUERY_TILES:
-    case OptimizationTarget::
-        OPTIMIZATION_TARGET_SEGMENTATION_CHROME_LOW_USER_ENGAGEMENT:
-    case OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_FEED_USER:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_VOICE:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_DUMMY:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_CHROME_START_ANDROID:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_QUERY_TILES:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_CHROME_LOW_USER_ENGAGEMENT:
+    case SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_FEED_USER:
       // Assumes all models return score between 0 and 1. This is true for all
       // the models we have currently.
       base::UmaHistogramPercentage(
@@ -315,8 +305,8 @@
 
 void RecordSegmentSelectionComputed(
     const std::string& segmentation_key,
-    OptimizationTarget new_selection,
-    absl::optional<OptimizationTarget> previous_selection) {
+    SegmentId new_selection,
+    absl::optional<SegmentId> previous_selection) {
   // Special case adaptive toolbar since it already has histograms being
   // recorded and updating names will affect current work.
   if (segmentation_key == kAdaptiveToolbarSegmentationKey) {
@@ -330,10 +320,9 @@
   base::UmaHistogramEnumeration(
       computed_hist, OptimizationTargetToSegmentationModel(new_selection));
 
-  OptimizationTarget prev_segment =
-      previous_selection.has_value()
-          ? previous_selection.value()
-          : OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN;
+  SegmentId prev_segment = previous_selection.has_value()
+                               ? previous_selection.value()
+                               : SegmentId::OPTIMIZATION_TARGET_UNKNOWN;
 
   if (prev_segment == new_selection)
     return;
@@ -371,15 +360,14 @@
       "SegmentationPlatform.Maintenance.SignalIdentifierCount", count);
 }
 
-void RecordModelDeliveryHasMetadata(OptimizationTarget segment_id,
-                                    bool has_metadata) {
+void RecordModelDeliveryHasMetadata(SegmentId segment_id, bool has_metadata) {
   base::UmaHistogramBoolean(
       "SegmentationPlatform.ModelDelivery.HasMetadata." +
           OptimizationTargetToHistogramVariant(segment_id),
       has_metadata);
 }
 
-void RecordModelDeliveryMetadataFeatureCount(OptimizationTarget segment_id,
+void RecordModelDeliveryMetadataFeatureCount(SegmentId segment_id,
                                              size_t count) {
   base::UmaHistogramCounts1000(
       "SegmentationPlatform.ModelDelivery.Metadata.FeatureCount." +
@@ -388,7 +376,7 @@
 }
 
 void RecordModelDeliveryMetadataValidation(
-    OptimizationTarget segment_id,
+    SegmentId segment_id,
     bool processed,
     metadata_utils::ValidationResult validation_result) {
   // Should map to ValidationPhase variant string in
@@ -401,37 +389,34 @@
       validation_result);
 }
 
-void RecordModelDeliveryReceived(OptimizationTarget segment_id) {
+void RecordModelDeliveryReceived(SegmentId segment_id) {
   UMA_HISTOGRAM_ENUMERATION("SegmentationPlatform.ModelDelivery.Received",
                             OptimizationTargetToSegmentationModel(segment_id));
 }
 
-void RecordModelDeliverySaveResult(OptimizationTarget segment_id,
-                                   bool success) {
+void RecordModelDeliverySaveResult(SegmentId segment_id, bool success) {
   base::UmaHistogramBoolean(
       "SegmentationPlatform.ModelDelivery.SaveResult." +
           OptimizationTargetToHistogramVariant(segment_id),
       success);
 }
 
-void RecordModelDeliverySegmentIdMatches(OptimizationTarget segment_id,
-                                         bool matches) {
+void RecordModelDeliverySegmentIdMatches(SegmentId segment_id, bool matches) {
   base::UmaHistogramBoolean(
       "SegmentationPlatform.ModelDelivery.SegmentIdMatches." +
           OptimizationTargetToHistogramVariant(segment_id),
       matches);
 }
 
-void RecordModelExecutionDurationFeatureProcessing(
-    OptimizationTarget segment_id,
-    base::TimeDelta duration) {
+void RecordModelExecutionDurationFeatureProcessing(SegmentId segment_id,
+                                                   base::TimeDelta duration) {
   base::UmaHistogramTimes(
       "SegmentationPlatform.ModelExecution.Duration.FeatureProcessing." +
           OptimizationTargetToHistogramVariant(segment_id),
       duration);
 }
 
-void RecordModelExecutionDurationModel(OptimizationTarget segment_id,
+void RecordModelExecutionDurationModel(SegmentId segment_id,
                                        bool success,
                                        base::TimeDelta duration) {
   ModelExecutionStatus status = success ? ModelExecutionStatus::kSuccess
@@ -447,7 +432,7 @@
       duration);
 }
 
-void RecordModelExecutionDurationTotal(OptimizationTarget segment_id,
+void RecordModelExecutionDurationTotal(SegmentId segment_id,
                                        ModelExecutionStatus status,
                                        base::TimeDelta duration) {
   absl::optional<base::StringPiece> status_variant =
@@ -461,22 +446,21 @@
       duration);
 }
 
-void RecordModelExecutionResult(OptimizationTarget segment_id, float result) {
+void RecordModelExecutionResult(SegmentId segment_id, float result) {
   base::UmaHistogramPercentage(
       "SegmentationPlatform.ModelExecution.Result." +
           OptimizationTargetToHistogramVariant(segment_id),
       result * 100);
 }
 
-void RecordModelExecutionSaveResult(OptimizationTarget segment_id,
-                                    bool success) {
+void RecordModelExecutionSaveResult(SegmentId segment_id, bool success) {
   base::UmaHistogramBoolean(
       "SegmentationPlatform.ModelExecution.SaveResult." +
           OptimizationTargetToHistogramVariant(segment_id),
       success);
 }
 
-void RecordModelExecutionStatus(OptimizationTarget segment_id,
+void RecordModelExecutionStatus(SegmentId segment_id,
                                 bool default_provider,
                                 ModelExecutionStatus status) {
   if (!default_provider) {
@@ -492,7 +476,7 @@
   }
 }
 
-void RecordModelExecutionZeroValuePercent(OptimizationTarget segment_id,
+void RecordModelExecutionZeroValuePercent(SegmentId segment_id,
                                           const std::vector<float>& tensor) {
   base::UmaHistogramPercentage(
       "SegmentationPlatform.ModelExecution.ZeroValuePercent." +
@@ -551,7 +535,7 @@
       reason);
 }
 
-void RecordModelAvailability(OptimizationTarget segment_id,
+void RecordModelAvailability(SegmentId segment_id,
                              SegmentationModelAvailability availability) {
   base::UmaHistogramEnumeration(
       "SegmentationPlatform.ModelAvailability." +
@@ -565,7 +549,7 @@
       tensor_size);
 }
 
-void RecordTrainingDataCollectionEvent(OptimizationTarget segment_id,
+void RecordTrainingDataCollectionEvent(SegmentId segment_id,
                                        TrainingDataCollectionEvent event) {
   base::UmaHistogramEnumeration(
       "SegmentationPlatform.TrainingDataCollectionEvents." +
diff --git a/components/segmentation_platform/internal/stats.h b/components/segmentation_platform/internal/stats.h
index 3a468da..ac59117d 100644
--- a/components/segmentation_platform/internal/stats.h
+++ b/components/segmentation_platform/internal/stats.h
@@ -5,17 +5,17 @@
 #ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_STATS_H_
 #define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_STATS_H_
 
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/execution/model_execution_status.h"
 #include "components/segmentation_platform/internal/metadata/metadata_utils.h"
 #include "components/segmentation_platform/internal/proto/types.pb.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 #include "components/segmentation_platform/public/segment_selection_result.h"
 #include "third_party/abseil-cpp/absl/types/optional.h"
 
-using optimization_guide::proto::OptimizationTarget;
-
 namespace segmentation_platform::stats {
 
+using proto::SegmentId;
+
 // Keep in sync with AdaptiveToolbarSegmentSwitch in enums.xml.
 // Visible for testing.
 enum class AdaptiveToolbarSegmentSwitch {
@@ -45,20 +45,20 @@
 };
 
 // Returns an UMA display string for the given segment_id.
-std::string OptimizationTargetToHistogramVariant(OptimizationTarget segment_id);
+std::string OptimizationTargetToHistogramVariant(SegmentId segment_id);
 
 // Returns an UMA display string for the given `segmentation_key`.
 const char* SegmentationKeyToUmaName(const std::string& segmentation_key);
 
 // Records the score computed for a given segment.
-void RecordModelScore(OptimizationTarget segment_id, float score);
+void RecordModelScore(SegmentId segment_id, float score);
 
 // Records the result of segment selection whenever segment selection is
 // computed.
 void RecordSegmentSelectionComputed(
     const std::string& segmentation_key,
-    OptimizationTarget new_selection,
-    absl::optional<OptimizationTarget> previous_selection);
+    SegmentId new_selection,
+    absl::optional<SegmentId> previous_selection);
 
 // Database Maintenance metrics.
 // Records the number of unique signal identifiers that were successfully
@@ -74,61 +74,57 @@
 // Model Delivery metrics.
 // Records whether any incoming ML model had metadata attached that we were able
 // to parse.
-void RecordModelDeliveryHasMetadata(OptimizationTarget segment_id,
-                                    bool has_metadata);
+void RecordModelDeliveryHasMetadata(SegmentId segment_id, bool has_metadata);
 // Records the number of tensor features an updated ML model has.
-void RecordModelDeliveryMetadataFeatureCount(OptimizationTarget segment_id,
+void RecordModelDeliveryMetadataFeatureCount(SegmentId segment_id,
                                              size_t count);
 // Records the result of validating the metadata of an incoming ML model.
 // Recorded before and after it has been merged with the already stored
 // metadata.
 void RecordModelDeliveryMetadataValidation(
-    OptimizationTarget segment_id,
+    SegmentId segment_id,
     bool processed,
     metadata_utils::ValidationResult validation_result);
 // Record what type of model metadata we received.
-void RecordModelDeliveryReceived(OptimizationTarget segment_id);
+void RecordModelDeliveryReceived(SegmentId segment_id);
 // Records the result of attempting to save an updated version of the model
 // metadata.
-void RecordModelDeliverySaveResult(OptimizationTarget segment_id, bool success);
+void RecordModelDeliverySaveResult(SegmentId segment_id, bool success);
 // Records whether the currently stored segment_id matches the incoming
 // segment_id, as these are expected to match.
-void RecordModelDeliverySegmentIdMatches(OptimizationTarget segment_id,
-                                         bool matches);
+void RecordModelDeliverySegmentIdMatches(SegmentId segment_id, bool matches);
 
 // Model Execution metrics.
 // Records the duration of processing a single ML feature. This only takes into
 // account the time it takes to process (aggregate) a feature result, not
 // fetching it from the database. It also takes into account filtering any
 // enum histograms.
-void RecordModelExecutionDurationFeatureProcessing(
-    OptimizationTarget segment_id,
-    base::TimeDelta duration);
+void RecordModelExecutionDurationFeatureProcessing(SegmentId segment_id,
+                                                   base::TimeDelta duration);
 // Records the duration of executing an ML model. This only takes into account
 // the time it takes to invoke and wait for a result from the underlying ML
 // infrastructure from //components/optimization_guide, and not fetching the
 // relevant data from the database.
-void RecordModelExecutionDurationModel(OptimizationTarget segment_id,
+void RecordModelExecutionDurationModel(SegmentId segment_id,
                                        bool success,
                                        base::TimeDelta duration);
 // Records the duration of fetching data for, processing, and executing an ML
 // model.
-void RecordModelExecutionDurationTotal(OptimizationTarget segment_id,
+void RecordModelExecutionDurationTotal(SegmentId segment_id,
                                        ModelExecutionStatus status,
                                        base::TimeDelta duration);
 // Records the result value after successfully executing an ML model.
-void RecordModelExecutionResult(OptimizationTarget segment_id, float result);
+void RecordModelExecutionResult(SegmentId segment_id, float result);
-// Records whether the result value of of executing an ML model was successfully
+// Records whether the result value of executing an ML model was successfully
 // saved.
-void RecordModelExecutionSaveResult(OptimizationTarget segment_id,
-                                    bool success);
+void RecordModelExecutionSaveResult(SegmentId segment_id, bool success);
 // Records the final execution status for any ML model execution.
-void RecordModelExecutionStatus(OptimizationTarget segment_id,
+void RecordModelExecutionStatus(SegmentId segment_id,
                                 bool default_provider,
                                 ModelExecutionStatus status);
 // Records the percent of features in a tensor that are equal to 0 when the
 // segmentation model is executed.
-void RecordModelExecutionZeroValuePercent(OptimizationTarget segment_id,
+void RecordModelExecutionZeroValuePercent(SegmentId segment_id,
                                           const std::vector<float>& tensor);
 
 // Signal Database metrics.
@@ -188,7 +184,7 @@
   kMaxValue = kMetadataInvalid
 };
 // Records the availability of segmentation models for each target needed.
-void RecordModelAvailability(OptimizationTarget segment_id,
+void RecordModelAvailability(SegmentId segment_id,
                              SegmentationModelAvailability availability);
 
 // Records the number of input tensor that's causing a failure to upload
@@ -212,7 +208,7 @@
 };
 
 // Records analytics for training data collection.
-void RecordTrainingDataCollectionEvent(OptimizationTarget segment_id,
+void RecordTrainingDataCollectionEvent(SegmentId segment_id,
                                        TrainingDataCollectionEvent event);
 
 }  // namespace segmentation_platform::stats
diff --git a/components/segmentation_platform/internal/stats_unittest.cc b/components/segmentation_platform/internal/stats_unittest.cc
index 722aa165..cc8d085 100644
--- a/components/segmentation_platform/internal/stats_unittest.cc
+++ b/components/segmentation_platform/internal/stats_unittest.cc
@@ -5,9 +5,9 @@
 #include "components/segmentation_platform/internal/stats.h"
 
 #include "base/test/metrics/histogram_tester.h"
-#include "components/optimization_guide/proto/models.pb.h"
 #include "components/segmentation_platform/internal/proto/types.pb.h"
 #include "components/segmentation_platform/public/config.h"
+#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
 #include "testing/gmock/include/gmock/gmock.h"
 #include "testing/gtest/include/gtest/gtest.h"
 
@@ -25,44 +25,40 @@
   std::vector<float> all_non_zero{1, 2, 3};
 
   RecordModelExecutionZeroValuePercent(
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, empty);
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, empty);
   EXPECT_EQ(
       1, tester.GetBucketCount(
              "SegmentationPlatform.ModelExecution.ZeroValuePercent.NewTab", 0));
 
   RecordModelExecutionZeroValuePercent(
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
-      single_zero);
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, single_zero);
   EXPECT_EQ(
       1,
       tester.GetBucketCount(
           "SegmentationPlatform.ModelExecution.ZeroValuePercent.NewTab", 100));
 
   RecordModelExecutionZeroValuePercent(
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
-      single_non_zero);
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, single_non_zero);
   EXPECT_EQ(
       2, tester.GetBucketCount(
              "SegmentationPlatform.ModelExecution.ZeroValuePercent.NewTab", 0));
 
   RecordModelExecutionZeroValuePercent(
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, all_zeroes);
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, all_zeroes);
   EXPECT_EQ(
       2,
       tester.GetBucketCount(
           "SegmentationPlatform.ModelExecution.ZeroValuePercent.NewTab", 100));
 
   RecordModelExecutionZeroValuePercent(
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
-      one_non_zero);
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, one_non_zero);
   EXPECT_EQ(
       1,
       tester.GetBucketCount(
           "SegmentationPlatform.ModelExecution.ZeroValuePercent.NewTab", 66));
 
   RecordModelExecutionZeroValuePercent(
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
-      all_non_zero);
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, all_non_zero);
   EXPECT_EQ(
       3, tester.GetBucketCount(
              "SegmentationPlatform.ModelExecution.ZeroValuePercent.NewTab", 0));
@@ -75,20 +71,19 @@
   // Share -> New tab.
   RecordSegmentSelectionComputed(
       kAdaptiveToolbarSegmentationKey,
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
 
   // None -> Share.
   RecordSegmentSelectionComputed(
       kAdaptiveToolbarSegmentationKey,
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE,
-      absl::nullopt);
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE, absl::nullopt);
 
   // Share -> Share.
   RecordSegmentSelectionComputed(
       kAdaptiveToolbarSegmentationKey,
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE,
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE,
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
   tester.ExpectTotalCount(histogram, 2);
 
   EXPECT_THAT(
@@ -111,9 +106,8 @@
   // Start to none.
   RecordSegmentSelectionComputed(
       kChromeStartAndroidSegmentationKey,
-      OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN,
-      OptimizationTarget::
-          OPTIMIZATION_TARGET_SEGMENTATION_CHROME_START_ANDROID);
+      SegmentId::OPTIMIZATION_TARGET_UNKNOWN,
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_CHROME_START_ANDROID);
 
   tester.ExpectTotalCount(histogram, 1);
   EXPECT_THAT(tester.GetAllSamples(histogram),
@@ -122,7 +116,7 @@
   // None to start.
   RecordSegmentSelectionComputed(
       kChromeStartAndroidSegmentationKey,
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_CHROME_START_ANDROID,
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_CHROME_START_ANDROID,
       absl::nullopt);
 
   tester.ExpectTotalCount(histogram, 2);
@@ -164,7 +158,7 @@
 TEST(StatsTest, TrainingDataCollectionEvent) {
   base::HistogramTester tester;
   RecordTrainingDataCollectionEvent(
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE,
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE,
       TrainingDataCollectionEvent::kImmediateCollectionStart);
   EXPECT_EQ(1,
             tester.GetBucketCount(
diff --git a/components/segmentation_platform/internal/ukm_data_manager_impl_unittest.cc b/components/segmentation_platform/internal/ukm_data_manager_impl_unittest.cc
index e944d56..bbaea34 100644
--- a/components/segmentation_platform/internal/ukm_data_manager_impl_unittest.cc
+++ b/components/segmentation_platform/internal/ukm_data_manager_impl_unittest.cc
@@ -113,11 +113,10 @@
   }
 
   void AddModel(const proto::SegmentationModelMetadata& metadata) {
-    auto& callback =
-        model_provider_data_.model_providers_callbacks
-            [OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE];
-    callback.Run(OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE,
-                 metadata, 0);
+    auto& callback = model_provider_data_.model_providers_callbacks
+                         [SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE];
+    callback.Run(SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE, metadata,
+                 0);
     segment_db_->GetCallback(true);
     segment_db_->UpdateCallback(true);
     segment_db_->LoadCallback(true);
@@ -128,7 +127,7 @@
     return *segmentation_platform_service_impl_;
   }
 
-  void SaveSegmentResult(OptimizationTarget segment_id,
+  void SaveSegmentResult(SegmentId segment_id,
                          absl::optional<proto::PredictionResult> result) {
     const std::string key = base::NumberToString(static_cast<int>(segment_id));
     auto& segment_info = segment_db_entries_[key];
@@ -142,7 +141,7 @@
     }
   }
 
-  bool HasSegmentResult(OptimizationTarget segment_id) {
+  bool HasSegmentResult(SegmentId segment_id) {
     const std::string key = base::NumberToString(static_cast<int>(segment_id));
     const auto it = segment_db_entries_.find(key);
     if (it == segment_db_entries_.end())
@@ -218,8 +217,8 @@
 
 TEST_F(UkmDataManagerImplTest, HistoryNotification) {
   const GURL kUrl1 = GURL("https://www.url1.com/");
-  const OptimizationTarget kSegmentId =
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE;
+  const SegmentId kSegmentId =
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE;
 
   TestServicesForPlatform& platform1 = CreatePlatform();
   platform1.AddModel(PageLoadModelMetadata());
@@ -253,7 +252,7 @@
 
   // History based segment results should be removed.
   EXPECT_FALSE(platform1.HasSegmentResult(
-      OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE));
+      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE));
 
   RemovePlatform(&platform1);
 }