blob: 00909eca4c98e3b26d1cc3bccb9b0b711ec619fa [file] [log] [blame]
// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifdef UNSAFE_BUFFERS_BUILD
// TODO(crbug.com/40285824): Remove this and convert code to safer constructs.
#pragma allow_unsafe_buffers
#endif
#include "components/segmentation_platform/internal/segmentation_ukm_helper.h"
#include "base/bit_cast.h"
#include "base/metrics/field_trial_params.h"
#include "base/rand_util.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/string_split.h"
#include "base/time/clock.h"
#include "base/time/time.h"
#include "components/segmentation_platform/internal/constants.h"
#include "components/segmentation_platform/internal/selection/segmentation_result_prefs.h"
#include "components/segmentation_platform/internal/stats.h"
#include "components/segmentation_platform/public/config.h"
#include "components/segmentation_platform/public/features.h"
#include "components/segmentation_platform/public/local_state_helper.h"
#include "services/metrics/public/cpp/ukm_builders.h"
#include "services/metrics/public/cpp/ukm_recorder.h"
#include "services/metrics/public/cpp/ukm_source_id.h"
#define CALL_MEMBER_FN(obj, func) ((obj).*(func))
#define ARRAY_SIZE(x) (sizeof(x) / sizeof(x)[0])
using segmentation_platform::SegmentationUkmHelper;
using segmentation_platform::proto::SegmentId;
using ukm::builders::Segmentation_ModelExecution;
namespace segmentation_platform {
namespace {
using UkmMemberFn =
Segmentation_ModelExecution& (Segmentation_ModelExecution::*)(int64_t);
const UkmMemberFn kSegmentationUkmInputMethods[] = {
&Segmentation_ModelExecution::SetInput0,
&Segmentation_ModelExecution::SetInput1,
&Segmentation_ModelExecution::SetInput2,
&Segmentation_ModelExecution::SetInput3,
&Segmentation_ModelExecution::SetInput4,
&Segmentation_ModelExecution::SetInput5,
&Segmentation_ModelExecution::SetInput6,
&Segmentation_ModelExecution::SetInput7,
&Segmentation_ModelExecution::SetInput8,
&Segmentation_ModelExecution::SetInput9,
&Segmentation_ModelExecution::SetInput10,
&Segmentation_ModelExecution::SetInput11,
&Segmentation_ModelExecution::SetInput12,
&Segmentation_ModelExecution::SetInput13,
&Segmentation_ModelExecution::SetInput14,
&Segmentation_ModelExecution::SetInput15,
&Segmentation_ModelExecution::SetInput16,
&Segmentation_ModelExecution::SetInput17,
&Segmentation_ModelExecution::SetInput18,
&Segmentation_ModelExecution::SetInput19,
&Segmentation_ModelExecution::SetInput20,
&Segmentation_ModelExecution::SetInput21,
&Segmentation_ModelExecution::SetInput22,
&Segmentation_ModelExecution::SetInput23,
&Segmentation_ModelExecution::SetInput24,
&Segmentation_ModelExecution::SetInput25,
&Segmentation_ModelExecution::SetInput26,
&Segmentation_ModelExecution::SetInput27,
&Segmentation_ModelExecution::SetInput28,
&Segmentation_ModelExecution::SetInput29,
&Segmentation_ModelExecution::SetInput30,
&Segmentation_ModelExecution::SetInput31,
&Segmentation_ModelExecution::SetInput32,
&Segmentation_ModelExecution::SetInput33,
&Segmentation_ModelExecution::SetInput34,
&Segmentation_ModelExecution::SetInput35,
&Segmentation_ModelExecution::SetInput36,
&Segmentation_ModelExecution::SetInput37,
&Segmentation_ModelExecution::SetInput38,
&Segmentation_ModelExecution::SetInput39,
&Segmentation_ModelExecution::SetInput40,
&Segmentation_ModelExecution::SetInput41,
&Segmentation_ModelExecution::SetInput42,
&Segmentation_ModelExecution::SetInput43,
&Segmentation_ModelExecution::SetInput44,
&Segmentation_ModelExecution::SetInput45,
&Segmentation_ModelExecution::SetInput46,
&Segmentation_ModelExecution::SetInput47,
&Segmentation_ModelExecution::SetInput48,
&Segmentation_ModelExecution::SetInput49};
const UkmMemberFn kSegmentationUkmPredictionResultMethods[] = {
&Segmentation_ModelExecution::SetPredictionResult1,
&Segmentation_ModelExecution::SetPredictionResult2,
&Segmentation_ModelExecution::SetPredictionResult3,
&Segmentation_ModelExecution::SetPredictionResult4,
&Segmentation_ModelExecution::SetPredictionResult5,
&Segmentation_ModelExecution::SetPredictionResult6,
&Segmentation_ModelExecution::SetPredictionResult7,
&Segmentation_ModelExecution::SetPredictionResult8,
&Segmentation_ModelExecution::SetPredictionResult9,
&Segmentation_ModelExecution::SetPredictionResult10};
const UkmMemberFn kSegmentationUkmOutputMethods[] = {
&Segmentation_ModelExecution::SetActualResult,
&Segmentation_ModelExecution::SetActualResult2,
&Segmentation_ModelExecution::SetActualResult3,
&Segmentation_ModelExecution::SetActualResult4,
&Segmentation_ModelExecution::SetActualResult5,
&Segmentation_ModelExecution::SetActualResult6};
// 1 out of 100 model execution will be reported.
const int kDefaultModelExecutionSamplingRate = 100;
int GetModelExecutionSamplingRate() {
return base::GetFieldTrialParamByFeatureAsInt(
segmentation_platform::features::
kSegmentationPlatformModelExecutionSampling,
segmentation_platform::kModelExecutionSamplingRateKey,
kDefaultModelExecutionSamplingRate);
}
// Helper method to add model prediction results to UKM log.
void AddPredictionResultToUkmModelExecution(
ukm::builders::Segmentation_ModelExecution* model_execution,
const std::vector<float>& results) {
CHECK_LE(results.size(), ARRAY_SIZE(kSegmentationUkmPredictionResultMethods));
for (size_t i = 0; i < results.size(); ++i) {
CALL_MEMBER_FN(*model_execution, kSegmentationUkmPredictionResultMethods[i])
(SegmentationUkmHelper::FloatToInt64(results[i]));
}
}
std::string GetDebugString(const ModelProvider::Request& input_tensor,
const ModelProvider::Response& outputs) {
std::stringstream out;
out << "Inputs: ";
int j = 0;
for (const auto& i : input_tensor) {
out << j++ << ":" << i << " ";
}
out << " Outputs: ";
j = 0;
for (const auto& i : outputs) {
out << j++ << ":" << i << " ";
}
return out.str();
}
} // namespace
SegmentationUkmHelper::SegmentationUkmHelper() {
Initialize();
}
SegmentationUkmHelper::~SegmentationUkmHelper() = default;
// static
SegmentationUkmHelper* SegmentationUkmHelper::GetInstance() {
static base::NoDestructor<SegmentationUkmHelper> helper;
return helper.get();
}
void SegmentationUkmHelper::Initialize() {
// TODO(crbug.com/40887237): Migrate models for these segments to use
// `upload_tensors`.
allowed_segment_ids_.clear();
if (base::FeatureList::IsEnabled(segmentation_platform::features::
kSegmentationDefaultReportingSegments)) {
allowed_segment_ids_ = base::flat_set<SegmentId>{
SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE,
SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_VOICE,
SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_DUMMY};
}
sampling_rate_ = GetModelExecutionSamplingRate();
DCHECK_GE(sampling_rate_, 0);
}
ukm::SourceId SegmentationUkmHelper::RecordModelExecutionResult(
SegmentId segment_id,
int64_t model_version,
const ModelProvider::Request& input_tensor,
const std::vector<float>& results) {
ukm::SourceId source_id = ukm::NoURLSourceId();
// Do some sampling before sending out UKM.
if (sampling_rate_ == 0) {
return source_id;
}
if (base::RandInt(1, sampling_rate_) > 1) {
return source_id;
}
ukm::builders::Segmentation_ModelExecution execution_result(source_id);
// Add inputs to ukm message.
if (!AddInputsToUkm(&execution_result, segment_id, model_version,
input_tensor)) {
return ukm::kInvalidSourceId;
}
AddPredictionResultToUkmModelExecution(&execution_result, results);
execution_result.Record(ukm::UkmRecorder::Get());
return source_id;
}
ukm::SourceId SegmentationUkmHelper::RecordTrainingData(
SegmentId segment_id,
int64_t model_version,
ukm::SourceId ukm_source_id,
const ModelProvider::Request& input_tensor,
const ModelProvider::Response& outputs,
const std::vector<int>& output_indexes,
std::optional<proto::PredictionResult> prediction_result,
std::optional<SelectedSegment> selected_segment) {
ukm::SourceId source_id = ukm_source_id != ukm::kInvalidSourceId
? ukm_source_id
: ukm::NoURLSourceId();
ukm::builders::Segmentation_ModelExecution execution_result(source_id);
if (!AddInputsToUkm(&execution_result, segment_id, model_version,
input_tensor)) {
return ukm::kInvalidSourceId;
}
if (!AddOutputsToUkm(&execution_result, outputs, output_indexes)) {
return ukm::kInvalidSourceId;
}
if (prediction_result.has_value() && prediction_result->result_size() > 0) {
std::vector<float> results(prediction_result->result().begin(),
prediction_result->result().end());
AddPredictionResultToUkmModelExecution(&execution_result, results);
base::Time prediction_time = base::Time::FromDeltaSinceWindowsEpoch(
base::Microseconds(prediction_result->timestamp_us()));
execution_result.SetOutputDelaySec(
(base::Time::Now() - prediction_time).InSeconds());
}
if (selected_segment.has_value()) {
execution_result.SetSelectionResult(selected_segment->segment_id);
execution_result.SetOutputDelaySec(
(base::Time::Now() - selected_segment->selection_time).InSeconds());
}
VLOG(1) << "Recording training data " << proto::SegmentId_Name(segment_id)
<< " " << GetDebugString(input_tensor, outputs);
execution_result.Record(ukm::UkmRecorder::Get());
return source_id;
}
bool SegmentationUkmHelper::AddInputsToUkm(
ukm::builders::Segmentation_ModelExecution* ukm_builder,
SegmentId segment_id,
int64_t model_version,
const ModelProvider::Request& input_tensor) {
if (input_tensor.size() > ARRAY_SIZE(kSegmentationUkmInputMethods)) {
// Don't record UKM if there are too many tensors.
stats::RecordTooManyInputTensors(input_tensor.size());
return false;
}
ukm_builder->SetOptimizationTarget(segment_id).SetModelVersion(model_version);
for (size_t i = 0; i < input_tensor.size(); ++i) {
CALL_MEMBER_FN(*ukm_builder, kSegmentationUkmInputMethods[i])
(FloatToInt64(input_tensor[i]));
}
return true;
}
bool SegmentationUkmHelper::AddOutputsToUkm(
ukm::builders::Segmentation_ModelExecution* ukm_builder,
const ModelProvider::Response& outputs,
const std::vector<int>& output_indexes) {
DCHECK(!outputs.empty());
if (outputs.size() != output_indexes.size())
return false;
const int output_methods_size = ARRAY_SIZE(kSegmentationUkmOutputMethods);
if (outputs.size() > output_methods_size)
return false;
for (size_t i = 0; i < outputs.size(); ++i) {
if (output_indexes[i] >= output_methods_size)
return false;
CALL_MEMBER_FN(*ukm_builder,
kSegmentationUkmOutputMethods[output_indexes[i]])
(FloatToInt64(outputs[i]));
}
return true;
}
bool SegmentationUkmHelper::IsUploadRequested(
const proto::SegmentInfo& segment_info) const {
return segment_info.model_metadata().upload_tensors() ||
allowed_segment_ids_.contains(segment_info.segment_id());
}
// static
int64_t SegmentationUkmHelper::FloatToInt64(float f) {
// Encode the float number in IEEE754 double precision.
return base::bit_cast<int64_t>(static_cast<double>(f));
}
// static
bool SegmentationUkmHelper::AllowedToUploadData(
base::TimeDelta signal_storage_length,
base::Clock* clock) {
base::Time most_recent_allowed = LocalStateHelper::GetInstance().GetPrefTime(
kSegmentationUkmMostRecentAllowedTimeKey);
// If the local state is never set, return false.
if (most_recent_allowed.is_null() ||
most_recent_allowed == base::Time::Max()) {
VLOG(1) << "UKM consent not granted";
return false;
}
if (most_recent_allowed + signal_storage_length < clock->Now()) {
return true;
} else {
VLOG(1) << "UKM consent granted on: " << most_recent_allowed
<< ". Waiting for the model's storage period ("
<< most_recent_allowed + signal_storage_length
<< ") to avoid uploading data collected pre-consent";
return false;
}
}
} // namespace segmentation_platform