blob: 4cc882080bcd4146cfd4fdc84c5ce763b369cebc [file] [log] [blame]
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5#ifndef COMPONENTS_BROWSING_TOPICS_ANNOTATOR_IMPL_H_
6#define COMPONENTS_BROWSING_TOPICS_ANNOTATOR_IMPL_H_
7
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "base/files/file_path.h"
#include "base/functional/callback.h"
#include "base/memory/weak_ptr.h"
#include "base/sequence_checker.h"
#include "base/task/sequenced_task_runner.h"
#include "components/browsing_topics/annotator.h"
#include "components/optimization_guide/core/bert_model_handler.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
20
namespace optimization_guide {
// Forward-declared only: this header refers to the provider solely through a
// pointer (the |AnnotatorImpl| constructor parameter), so no full definition
// is required here.
class OptimizationGuideModelProvider;
}  // namespace optimization_guide
24
25namespace browsing_topics {
26
27// An implementation of the |Annotator| base class. This Annotator supports
28// concurrent batch annotations and manages the lifetimes of all underlying
29// components. This class must only be owned and called on the UI thread.
30//
31// |BatchAnnotate| is the main entry point for callers. The callback given to
32// |BatchAnnotate| is forwarded through many subsequent PostTasks until all
33// annotations are ready to be returned to the caller.
34//
35// Life of an Annotation:
36// 1. |BatchAnnotate| checks if the override list needs to be loaded. If so, it
37// is done on a background thread. After that check and possibly loading the
38// list in |OnOverrideListLoadAttemptDone|, |StartBatchAnnotate| is called.
39// 2. |StartBatchAnnotate| shares ownership of the |BatchAnnotationCallback|
40// among a series of callbacks (using |base::BarrierClosure|), one for each
41// input. Ownership of the inputs is moved to the heap where all individual
42// model executions can reference their input and set their output.
43// 3. |AnnotateSingleInput| runs a single annotation, first checking the
44// override list if available. If the input is not covered in the override list,
45// the ML model is run on a background thread.
46// 4. |PostprocessCategoriesToBatchAnnotationResult| is called to post-process
47// the output of the ML model.
48// 5. |OnBatchComplete| is called by the barrier closure which passes the
49// annotations back to the caller and unloads the model if no other batches are
50// in progress.
51class AnnotatorImpl : public Annotator,
52 public optimization_guide::BertModelHandler {
53 public:
54 AnnotatorImpl(
55 optimization_guide::OptimizationGuideModelProvider* model_provider,
56 scoped_refptr<base::SequencedTaskRunner> background_task_runner,
57 const absl::optional<optimization_guide::proto::Any>& model_metadata);
58 ~AnnotatorImpl() override;
59
60 // Annotator:
61 void BatchAnnotate(BatchAnnotationCallback callback,
62 const std::vector<std::string>& inputs) override;
63 void NotifyWhenModelAvailable(base::OnceClosure callback) override;
64 absl::optional<optimization_guide::ModelInfo> GetBrowsingTopicsModelInfo()
65 const override;
66
67 //////////////////////////////////////////////////////////////////////////////
68 // Public methods below here are exposed only for testing.
69 //////////////////////////////////////////////////////////////////////////////
70
71 // optimization_guide::BertModelHandler:
72 void OnModelUpdated(
73 optimization_guide::proto::OptimizationTarget optimization_target,
rajendrant23411d82023-08-11 19:56:1774 base::optional_ref<const optimization_guide::ModelInfo> model_info)
75 override;
Robert Ogdenad99d6f62023-05-01 21:40:0976
77 // Extracts the scored categories from the output of the model.
78 absl::optional<std::vector<int32_t>> ExtractCategoriesFromModelOutput(
79 const std::vector<tflite::task::core::Category>& model_output) const;
80
Robert Ogden2f15ed62023-05-03 21:15:2381 protected:
82 // optimization_guide::BertModelHandler:
83 void UnloadModel() override;
84
Robert Ogdenad99d6f62023-05-01 21:40:0985 private:
86 // Sets the |override_list_| after it was loaded on a background thread and
87 // calls |StartBatchAnnotate|.
88 void OnOverrideListLoadAttemptDone(
89 BatchAnnotationCallback callback,
90 const std::vector<std::string>& inputs,
91 absl::optional<std::unordered_map<std::string, std::vector<int32_t>>>
92 override_list);
93
94 // Starts a batch annotation once the override list is loaded, if provided.
95 void StartBatchAnnotate(BatchAnnotationCallback callback,
96 const std::vector<std::string>& inputs);
97
98 // Does the required preprocessing on a input domain.
99 std::string PreprocessHost(const std::string& host) const;
100
101 // Runs a single input through the ML model, setting the result in
102 // |annotation|.
103 void AnnotateSingleInput(base::OnceClosure single_input_done_signal,
104 Annotation* annotation);
105
106 // Called when all single inputs have been annotated and the |callback| from
107 // the caller can finally be run.
108 void OnBatchComplete(
109 BatchAnnotationCallback callback,
110 std::unique_ptr<std::vector<Annotation>> annotations_ptr);
111
Robert Ogdenad99d6f62023-05-01 21:40:09112 // Sets |annotation.topics| from the output of the model, calling
113 // |ExtractCategoriesFromModelOutput| in the process.
114 void PostprocessCategoriesToBatchAnnotationResult(
115 base::OnceClosure single_input_done_signal,
116 Annotation* annotation,
117 const absl::optional<std::vector<tflite::task::core::Category>>& output);
118
119 // Used to read the override list file on a background thread.
120 scoped_refptr<base::SequencedTaskRunner> background_task_runner_;
121
122 // Set whenever a valid override list file is passed along with the model file
123 // update. Used on the UI thread.
124 absl::optional<base::FilePath> override_list_file_path_;
125
126 // Set whenever an override list file is available and the model file is
127 // loaded into memory. Reset whenever the model file is unloaded.
128 // Used on the UI thread. Lookups in this mapping should have |PreprocessHost|
129 // applied first.
130 absl::optional<std::unordered_map<std::string, std::vector<int32_t>>>
131 override_list_;
132
133 // The version of topics model provided by the server in the model metadata
134 // which specifies the expected functionality of execution not contained
135 // within the model itself (e.g., preprocessing/post processing).
136 int version_ = 0;
137
138 // Counts the number of batches that are in progress. This counter is
139 // incremented in |StartBatchAnnotate| and decremented in |OnBatchComplete|.
140 // When this counter is 0 in |OnBatchComplete|, the model in unloaded from
141 // memory.
142 size_t in_progess_batches_ = 0;
143
rajendrant23411d82023-08-11 19:56:17144 // Indicates whether the model received was valid. Model will be invalid when
145 // metadata versions are unsupported.
146 bool is_valid_model_ = false;
147
Robert Ogdenad99d6f62023-05-01 21:40:09148 SEQUENCE_CHECKER(sequence_checker_);
149
150 base::WeakPtrFactory<AnnotatorImpl> weak_ptr_factory_{this};
151};
152
153} // namespace browsing_topics
154
155#endif // COMPONENTS_BROWSING_TOPICS_ANNOTATOR_IMPL_H_