blob: c1c434493807acef022d2bef16f0af18acc77097 [file] [log] [blame]
// Copyright 2022 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5#include "components/segmentation_platform/internal/segmentation_ukm_helper.h"
6
7#include "base/bit_cast.h"
8#include "base/metrics/field_trial_params.h"
9#include "base/strings/string_number_conversions.h"
10#include "base/strings/string_split.h"
11#include "components/segmentation_platform/internal/stats.h"
12#include "components/segmentation_platform/public/config.h"
13#include "components/segmentation_platform/public/features.h"
14#include "services/metrics/public/cpp/ukm_builders.h"
15#include "services/metrics/public/cpp/ukm_recorder.h"
16
// Forms a call through the pointer-to-member-function |func| on object |obj|.
// The argument list is written by the caller directly after the macro, e.g.
// CALL_MEMBER_FN(obj, func)(arg).
#define CALL_MEMBER_FN(obj, func) ((obj).*(func))

// Evaluates to the number of elements of the built-in array |x|: the size of
// the whole array divided by the size of its first element.
#define ARRAY_SIZE(x) (sizeof(x) / sizeof((x)[0]))
19
20using ukm::builders::Segmentation_ModelExecution;
21
22namespace {
23using UkmMemberFn =
24 Segmentation_ModelExecution& (Segmentation_ModelExecution::*)(int64_t);
25
26const UkmMemberFn kSegmentationUkmMethods[] = {
27 &Segmentation_ModelExecution::SetInput0,
28 &Segmentation_ModelExecution::SetInput1,
29 &Segmentation_ModelExecution::SetInput2,
30 &Segmentation_ModelExecution::SetInput3,
31 &Segmentation_ModelExecution::SetInput4,
32 &Segmentation_ModelExecution::SetInput5,
33 &Segmentation_ModelExecution::SetInput6,
34 &Segmentation_ModelExecution::SetInput7,
35 &Segmentation_ModelExecution::SetInput8,
36 &Segmentation_ModelExecution::SetInput9,
37 &Segmentation_ModelExecution::SetInput10,
38 &Segmentation_ModelExecution::SetInput11,
39 &Segmentation_ModelExecution::SetInput12,
40 &Segmentation_ModelExecution::SetInput13,
41 &Segmentation_ModelExecution::SetInput14,
42 &Segmentation_ModelExecution::SetInput15,
43 &Segmentation_ModelExecution::SetInput16,
44 &Segmentation_ModelExecution::SetInput17,
45 &Segmentation_ModelExecution::SetInput18,
46 &Segmentation_ModelExecution::SetInput19,
47 &Segmentation_ModelExecution::SetInput20,
48 &Segmentation_ModelExecution::SetInput21,
49 &Segmentation_ModelExecution::SetInput22,
50 &Segmentation_ModelExecution::SetInput23,
51 &Segmentation_ModelExecution::SetInput24,
52 &Segmentation_ModelExecution::SetInput25,
53 &Segmentation_ModelExecution::SetInput26,
54 &Segmentation_ModelExecution::SetInput27,
55 &Segmentation_ModelExecution::SetInput28,
56 &Segmentation_ModelExecution::SetInput29};
57
58// Gets a set of segment IDs that are allowed to upload metrics.
59base::flat_set<int> GetSegmentIdsAllowedForReporting() {
60 std::vector<std::string> segment_ids = base::SplitString(
61 base::GetFieldTrialParamValueByFeature(
62 segmentation_platform::features::
63 kSegmentationStructuredMetricsFeature,
64 segmentation_platform::kSegmentIdsAllowedForReportingKey),
65 ",;", base::WhitespaceHandling::TRIM_WHITESPACE,
66 base::SplitResult::SPLIT_WANT_NONEMPTY);
67 base::flat_set<int> result;
68 for (const auto& id : segment_ids) {
69 int segment_id;
70 if (base::StringToInt(id, &segment_id))
71 result.emplace(segment_id);
72 }
73 return result;
74}
75
76} // namespace
77
78namespace segmentation_platform {
79
80SegmentationUkmHelper::SegmentationUkmHelper() {
81 Initialize();
82}
83
84SegmentationUkmHelper::~SegmentationUkmHelper() = default;
85
86void SegmentationUkmHelper::Initialize() {
87 allowed_segment_ids_ = GetSegmentIdsAllowedForReporting();
88}
89
90// static
91SegmentationUkmHelper* SegmentationUkmHelper::GetInstance() {
92 static base::NoDestructor<SegmentationUkmHelper> helper;
93 return helper.get();
94}
95
96ukm::SourceId SegmentationUkmHelper::RecordModelExecutionResult(
97 OptimizationTarget segment_id,
98 int64_t model_version,
99 const std::vector<float>& input_tensor,
100 float result) {
101 // Check if the |segment_id| is allowed to record metrics.
102 if (!allowed_segment_ids_.contains(static_cast<int>(segment_id)))
103 return ukm::kInvalidSourceId;
104
105 ukm::SourceId source_id = ukm::NoURLSourceId();
106 // Don't record UKM if there are too many tensors.
107 if (input_tensor.size() > ARRAY_SIZE(kSegmentationUkmMethods)) {
108 stats::RecordTooManyInputTensors(input_tensor.size());
109 return ukm::kInvalidSourceId;
110 }
111 ukm::builders::Segmentation_ModelExecution execution_result(source_id);
112
113 execution_result.SetOptimizationTarget(segment_id)
114 .SetModelVersion(model_version);
115
116 for (size_t i = 0; i < input_tensor.size(); ++i) {
117 CALL_MEMBER_FN(execution_result, kSegmentationUkmMethods[i])
118 (FloatToInt64(input_tensor[i]));
119 }
120 execution_result.SetPredictionResult(FloatToInt64(result))
121 .Record(ukm::UkmRecorder::Get());
122 return source_id;
123}
124
125// static
126int64_t SegmentationUkmHelper::FloatToInt64(float f) {
127 // Encode the float number in IEEE754 double precision.
128 return bit_cast<int64_t>(static_cast<double>(f));
129}
130
131} // namespace segmentation_platform