blob: 62cf6a7f24ae581b84caee3a716108591ee3ac78 [file] [log] [blame]
// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5#include "components/segmentation_platform/internal/segmentation_ukm_helper.h"
6
7#include "base/bit_cast.h"
8#include "base/metrics/field_trial_params.h"
9#include "base/strings/string_number_conversions.h"
10#include "base/strings/string_split.h"
Min Qincae984cb2022-05-20 03:08:3411#include "base/time/clock.h"
12#include "components/segmentation_platform/internal/constants.h"
Min Qin642ab222022-05-19 21:54:5313#include "components/segmentation_platform/internal/selection/segmentation_result_prefs.h"
Min Qin75bee1b2022-02-05 03:57:3414#include "components/segmentation_platform/internal/stats.h"
15#include "components/segmentation_platform/public/config.h"
16#include "components/segmentation_platform/public/features.h"
Min Qincae984cb2022-05-20 03:08:3417#include "components/segmentation_platform/public/local_state_helper.h"
Min Qin75bee1b2022-02-05 03:57:3418#include "services/metrics/public/cpp/ukm_builders.h"
19#include "services/metrics/public/cpp/ukm_recorder.h"
20
// Applies the pointer-to-member-function `member_fn` to `object`. The result
// is a callable expression: follow it with the argument list, e.g.
// CALL_MEMBER_FN(obj, fn)(arg).
#define CALL_MEMBER_FN(object, member_fn) ((object).*(member_fn))

// Number of elements in the built-in array `array`, computed as the size of
// the whole array divided by the size of its first element.
#define ARRAY_SIZE(array) (sizeof(array) / sizeof((array)[0]))
23
ssid8386fc72022-05-21 00:34:1724using segmentation_platform::proto::SegmentId;
Min Qin75bee1b2022-02-05 03:57:3425using ukm::builders::Segmentation_ModelExecution;
26
27namespace {
28using UkmMemberFn =
29 Segmentation_ModelExecution& (Segmentation_ModelExecution::*)(int64_t);
30
Xing Liu282dd0d2022-02-24 00:20:3931const UkmMemberFn kSegmentationUkmInputMethods[] = {
Min Qin75bee1b2022-02-05 03:57:3432 &Segmentation_ModelExecution::SetInput0,
33 &Segmentation_ModelExecution::SetInput1,
34 &Segmentation_ModelExecution::SetInput2,
35 &Segmentation_ModelExecution::SetInput3,
36 &Segmentation_ModelExecution::SetInput4,
37 &Segmentation_ModelExecution::SetInput5,
38 &Segmentation_ModelExecution::SetInput6,
39 &Segmentation_ModelExecution::SetInput7,
40 &Segmentation_ModelExecution::SetInput8,
41 &Segmentation_ModelExecution::SetInput9,
42 &Segmentation_ModelExecution::SetInput10,
43 &Segmentation_ModelExecution::SetInput11,
44 &Segmentation_ModelExecution::SetInput12,
45 &Segmentation_ModelExecution::SetInput13,
46 &Segmentation_ModelExecution::SetInput14,
47 &Segmentation_ModelExecution::SetInput15,
48 &Segmentation_ModelExecution::SetInput16,
49 &Segmentation_ModelExecution::SetInput17,
50 &Segmentation_ModelExecution::SetInput18,
51 &Segmentation_ModelExecution::SetInput19,
52 &Segmentation_ModelExecution::SetInput20,
53 &Segmentation_ModelExecution::SetInput21,
54 &Segmentation_ModelExecution::SetInput22,
55 &Segmentation_ModelExecution::SetInput23,
56 &Segmentation_ModelExecution::SetInput24,
57 &Segmentation_ModelExecution::SetInput25,
58 &Segmentation_ModelExecution::SetInput26,
59 &Segmentation_ModelExecution::SetInput27,
60 &Segmentation_ModelExecution::SetInput28,
61 &Segmentation_ModelExecution::SetInput29};
62
Xing Liu282dd0d2022-02-24 00:20:3963const UkmMemberFn kSegmentationUkmOutputMethods[] = {
64 &Segmentation_ModelExecution::SetActualResult,
65 &Segmentation_ModelExecution::SetActualResult2,
66 &Segmentation_ModelExecution::SetActualResult3,
67 &Segmentation_ModelExecution::SetActualResult4,
68 &Segmentation_ModelExecution::SetActualResult5,
69 &Segmentation_ModelExecution::SetActualResult6};
70
ssid8386fc72022-05-21 00:34:1771base::flat_set<SegmentId> GetSegmentIdsAllowedForReporting() {
Min Qin75bee1b2022-02-05 03:57:3472 std::vector<std::string> segment_ids = base::SplitString(
73 base::GetFieldTrialParamValueByFeature(
74 segmentation_platform::features::
75 kSegmentationStructuredMetricsFeature,
76 segmentation_platform::kSegmentIdsAllowedForReportingKey),
77 ",;", base::WhitespaceHandling::TRIM_WHITESPACE,
78 base::SplitResult::SPLIT_WANT_NONEMPTY);
ssid8386fc72022-05-21 00:34:1779 base::flat_set<SegmentId> result;
Min Qin75bee1b2022-02-05 03:57:3480 for (const auto& id : segment_ids) {
81 int segment_id;
82 if (base::StringToInt(id, &segment_id))
ssid8386fc72022-05-21 00:34:1783 result.emplace(static_cast<SegmentId>(segment_id));
Min Qin75bee1b2022-02-05 03:57:3484 }
85 return result;
86}
87
88} // namespace
89
90namespace segmentation_platform {
91
92SegmentationUkmHelper::SegmentationUkmHelper() {
93 Initialize();
94}
95
96SegmentationUkmHelper::~SegmentationUkmHelper() = default;
97
98void SegmentationUkmHelper::Initialize() {
99 allowed_segment_ids_ = GetSegmentIdsAllowedForReporting();
100}
101
102// static
103SegmentationUkmHelper* SegmentationUkmHelper::GetInstance() {
104 static base::NoDestructor<SegmentationUkmHelper> helper;
105 return helper.get();
106}
107
108ukm::SourceId SegmentationUkmHelper::RecordModelExecutionResult(
ssid8386fc72022-05-21 00:34:17109 SegmentId segment_id,
Min Qin75bee1b2022-02-05 03:57:34110 int64_t model_version,
111 const std::vector<float>& input_tensor,
112 float result) {
Min Qin75bee1b2022-02-05 03:57:34113 ukm::SourceId source_id = ukm::NoURLSourceId();
Min Qin75bee1b2022-02-05 03:57:34114 ukm::builders::Segmentation_ModelExecution execution_result(source_id);
115
Xing Liu282dd0d2022-02-24 00:20:39116 // Add inputs to ukm message.
117 if (!AddInputsToUkm(&execution_result, segment_id, model_version,
Min Qinf254aa92022-08-24 18:48:42118 input_tensor)) {
Xing Liu282dd0d2022-02-24 00:20:39119 return ukm::kInvalidSourceId;
Min Qinf254aa92022-08-24 18:48:42120 }
Min Qin75bee1b2022-02-05 03:57:34121
Xing Liu282dd0d2022-02-24 00:20:39122 // TODO(xingliu): Also record continuous outputs for model execution.
Min Qin75bee1b2022-02-05 03:57:34123 execution_result.SetPredictionResult(FloatToInt64(result))
124 .Record(ukm::UkmRecorder::Get());
125 return source_id;
126}
127
Xing Liu282dd0d2022-02-24 00:20:39128ukm::SourceId SegmentationUkmHelper::RecordTrainingData(
ssid8386fc72022-05-21 00:34:17129 SegmentId segment_id,
Xing Liu282dd0d2022-02-24 00:20:39130 int64_t model_version,
131 const std::vector<float>& input_tensor,
132 const std::vector<float>& outputs,
Min Qind52d150c2022-04-22 05:32:13133 const std::vector<int>& output_indexes,
Min Qin642ab222022-05-19 21:54:53134 absl::optional<proto::PredictionResult> prediction_result,
135 absl::optional<SelectedSegment> selected_segment) {
Xing Liu282dd0d2022-02-24 00:20:39136 ukm::SourceId source_id = ukm::NoURLSourceId();
137 ukm::builders::Segmentation_ModelExecution execution_result(source_id);
138 if (!AddInputsToUkm(&execution_result, segment_id, model_version,
139 input_tensor)) {
140 return ukm::kInvalidSourceId;
141 }
142
143 if (!AddOutputsToUkm(&execution_result, outputs, output_indexes)) {
144 return ukm::kInvalidSourceId;
145 }
146
ritikagupeaf525c2022-11-11 00:53:24147 if (prediction_result.has_value() && prediction_result->result_size() > 0) {
148 // TODO(ritikagup): Add support for uploading multiple outputs.
Min Qind52d150c2022-04-22 05:32:13149 execution_result.SetPredictionResult(
ritikagupeaf525c2022-11-11 00:53:24150 FloatToInt64(prediction_result->result()[0]));
Min Qind52d150c2022-04-22 05:32:13151 }
Min Qin642ab222022-05-19 21:54:53152 if (selected_segment.has_value()) {
153 execution_result.SetSelectionResult(selected_segment->segment_id);
154 execution_result.SetOutputDelaySec(
155 (base::Time::Now() - selected_segment->selection_time).InSeconds());
156 }
157
Xing Liu282dd0d2022-02-24 00:20:39158 execution_result.Record(ukm::UkmRecorder::Get());
159 return source_id;
160}
161
162bool SegmentationUkmHelper::AddInputsToUkm(
163 ukm::builders::Segmentation_ModelExecution* ukm_builder,
ssid8386fc72022-05-21 00:34:17164 SegmentId segment_id,
Xing Liu282dd0d2022-02-24 00:20:39165 int64_t model_version,
166 const std::vector<float>& input_tensor) {
Xing Liu282dd0d2022-02-24 00:20:39167 if (input_tensor.size() > ARRAY_SIZE(kSegmentationUkmInputMethods)) {
168 // Don't record UKM if there are too many tensors.
169 stats::RecordTooManyInputTensors(input_tensor.size());
170 return false;
171 }
172
173 ukm_builder->SetOptimizationTarget(segment_id).SetModelVersion(model_version);
174 for (size_t i = 0; i < input_tensor.size(); ++i) {
175 CALL_MEMBER_FN(*ukm_builder, kSegmentationUkmInputMethods[i])
176 (FloatToInt64(input_tensor[i]));
177 }
178 return true;
179}
180
181bool SegmentationUkmHelper::AddOutputsToUkm(
182 ukm::builders::Segmentation_ModelExecution* ukm_builder,
183 const std::vector<float>& outputs,
184 const std::vector<int>& output_indexes) {
185 DCHECK(!outputs.empty());
186 if (outputs.size() != output_indexes.size())
187 return false;
188
189 const int output_methods_size = ARRAY_SIZE(kSegmentationUkmOutputMethods);
190 if (outputs.size() > output_methods_size)
191 return false;
192
193 for (size_t i = 0; i < outputs.size(); ++i) {
194 if (output_indexes[i] >= output_methods_size)
195 return false;
196 CALL_MEMBER_FN(*ukm_builder,
197 kSegmentationUkmOutputMethods[output_indexes[i]])
198 (FloatToInt64(outputs[i]));
199 }
200
201 return true;
202}
203
Min Qinf254aa92022-08-24 18:48:42204bool SegmentationUkmHelper::CanUploadTensors(
205 const proto::SegmentInfo& segment_info) const {
206 if (!base::FeatureList::IsEnabled(
207 features::kSegmentationStructuredMetricsFeature)) {
208 return false;
209 }
210 return segment_info.model_metadata().upload_tensors() ||
211 allowed_segment_ids_.contains(
212 static_cast<int>(segment_info.segment_id()));
213}
214
Min Qin75bee1b2022-02-05 03:57:34215// static
216int64_t SegmentationUkmHelper::FloatToInt64(float f) {
217 // Encode the float number in IEEE754 double precision.
Peter Kastingcc88ac052022-05-03 09:58:01218 return base::bit_cast<int64_t>(static_cast<double>(f));
Min Qin75bee1b2022-02-05 03:57:34219}
220
Min Qincae984cb2022-05-20 03:08:34221// static
222bool SegmentationUkmHelper::AllowedToUploadData(
223 base::TimeDelta signal_storage_length,
224 base::Clock* clock) {
Min Qin0b78c782022-05-21 01:42:13225 base::Time most_recent_allowed = LocalStateHelper::GetInstance().GetPrefTime(
226 kSegmentationUkmMostRecentAllowedTimeKey);
227 // If the local state is never set, return false.
228 if (most_recent_allowed.is_null() ||
229 most_recent_allowed == base::Time::Max()) {
230 return false;
231 }
232 return most_recent_allowed + signal_storage_length < clock->Now();
Min Qincae984cb2022-05-20 03:08:34233}
234
Min Qin75bee1b2022-02-05 03:57:34235} // namespace segmentation_platform