blob: 7ff624eabbf48674566e1420fb1e90ea3488d61e [file] [log] [blame]
Robert Ogdenad99d6f62023-05-01 21:40:091// Copyright 2023 The Chromium Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#ifndef COMPONENTS_BROWSING_TOPICS_ANNOTATOR_IMPL_H_
6#define COMPONENTS_BROWSING_TOPICS_ANNOTATOR_IMPL_H_
7
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>

#include "base/callback_list.h"
#include "base/files/file_path.h"
#include "base/functional/callback.h"
#include "base/memory/weak_ptr.h"
#include "base/sequence_checker.h"
#include "base/task/sequenced_task_runner.h"
#include "base/types/optional_ref.h"
#include "components/browsing_topics/annotator.h"
#include "components/optimization_guide/core/bert_model_handler.h"
Robert Ogdenad99d6f62023-05-01 21:40:0921
namespace optimization_guide {
// Forward-declared rather than included: this header only refers to the
// provider through a raw pointer in the AnnotatorImpl constructor.
class OptimizationGuideModelProvider;
}  // namespace optimization_guide
25
26namespace browsing_topics {
27
28// An implementation of the |Annotator| base class. This Annotator supports
29// concurrent batch annotations and manages the lifetimes of all underlying
30// components. This class must only be owned and called on the UI thread.
31//
32// |BatchAnnotate| is the main entry point for callers. The callback given to
33// |BatchAnnotate| is forwarded through many subsequent PostTasks until all
34// annotations are ready to be returned to the caller.
35//
36// Life of an Annotation:
37// 1. |BatchAnnotate| checks if the override list needs to be loaded. If so, it
38// is done on a background thread. After that check and possibly loading the
39// list in |OnOverrideListLoadAttemptDone|, |StartBatchAnnotate| is called.
40// 2. |StartBatchAnnotate| shares ownership of the |BatchAnnotationCallback|
41// among a series of callbacks (using |base::BarrierClosure|), one for each
42// input. Ownership of the inputs is moved to the heap where all individual
43// model executions can reference their input and set their output.
44// 3. |AnnotateSingleInput| runs a single annotation, first checking the
45// override list if available. If the input is not covered in the override list,
46// the ML model is run on a background thread.
47// 4. |PostprocessCategoriesToBatchAnnotationResult| is called to post-process
48// the output of the ML model.
49// 5. |OnBatchComplete| is called by the barrier closure which passes the
50// annotations back to the caller and unloads the model if no other batches are
51// in progress.
52class AnnotatorImpl : public Annotator,
53 public optimization_guide::BertModelHandler {
54 public:
55 AnnotatorImpl(
56 optimization_guide::OptimizationGuideModelProvider* model_provider,
57 scoped_refptr<base::SequencedTaskRunner> background_task_runner,
Arthur Sonzognic571efb2024-01-26 20:26:1858 const std::optional<optimization_guide::proto::Any>& model_metadata);
Robert Ogdenad99d6f62023-05-01 21:40:0959 ~AnnotatorImpl() override;
60
61 // Annotator:
62 void BatchAnnotate(BatchAnnotationCallback callback,
63 const std::vector<std::string>& inputs) override;
64 void NotifyWhenModelAvailable(base::OnceClosure callback) override;
Arthur Sonzognic571efb2024-01-26 20:26:1865 std::optional<optimization_guide::ModelInfo> GetBrowsingTopicsModelInfo()
Robert Ogdenad99d6f62023-05-01 21:40:0966 const override;
67
68 //////////////////////////////////////////////////////////////////////////////
69 // Public methods below here are exposed only for testing.
70 //////////////////////////////////////////////////////////////////////////////
71
72 // optimization_guide::BertModelHandler:
73 void OnModelUpdated(
74 optimization_guide::proto::OptimizationTarget optimization_target,
rajendrant23411d82023-08-11 19:56:1775 base::optional_ref<const optimization_guide::ModelInfo> model_info)
76 override;
Robert Ogdenad99d6f62023-05-01 21:40:0977
78 // Extracts the scored categories from the output of the model.
Arthur Sonzognic571efb2024-01-26 20:26:1879 std::optional<std::vector<int32_t>> ExtractCategoriesFromModelOutput(
Robert Ogdenad99d6f62023-05-01 21:40:0980 const std::vector<tflite::task::core::Category>& model_output) const;
81
Robert Ogden2f15ed62023-05-03 21:15:2382 protected:
83 // optimization_guide::BertModelHandler:
84 void UnloadModel() override;
85
Robert Ogdenad99d6f62023-05-01 21:40:0986 private:
87 // Sets the |override_list_| after it was loaded on a background thread and
88 // calls |StartBatchAnnotate|.
89 void OnOverrideListLoadAttemptDone(
90 BatchAnnotationCallback callback,
91 const std::vector<std::string>& inputs,
Arthur Sonzognic571efb2024-01-26 20:26:1892 std::optional<std::unordered_map<std::string, std::vector<int32_t>>>
Robert Ogdenad99d6f62023-05-01 21:40:0993 override_list);
94
95 // Starts a batch annotation once the override list is loaded, if provided.
96 void StartBatchAnnotate(BatchAnnotationCallback callback,
97 const std::vector<std::string>& inputs);
98
99 // Does the required preprocessing on a input domain.
100 std::string PreprocessHost(const std::string& host) const;
101
102 // Runs a single input through the ML model, setting the result in
103 // |annotation|.
104 void AnnotateSingleInput(base::OnceClosure single_input_done_signal,
105 Annotation* annotation);
106
107 // Called when all single inputs have been annotated and the |callback| from
108 // the caller can finally be run.
109 void OnBatchComplete(
110 BatchAnnotationCallback callback,
111 std::unique_ptr<std::vector<Annotation>> annotations_ptr);
112
Robert Ogdenad99d6f62023-05-01 21:40:09113 // Sets |annotation.topics| from the output of the model, calling
114 // |ExtractCategoriesFromModelOutput| in the process.
115 void PostprocessCategoriesToBatchAnnotationResult(
116 base::OnceClosure single_input_done_signal,
117 Annotation* annotation,
Arthur Sonzognic571efb2024-01-26 20:26:18118 const std::optional<std::vector<tflite::task::core::Category>>& output);
Robert Ogdenad99d6f62023-05-01 21:40:09119
120 // Used to read the override list file on a background thread.
121 scoped_refptr<base::SequencedTaskRunner> background_task_runner_;
122
123 // Set whenever a valid override list file is passed along with the model file
124 // update. Used on the UI thread.
Arthur Sonzognic571efb2024-01-26 20:26:18125 std::optional<base::FilePath> override_list_file_path_;
Robert Ogdenad99d6f62023-05-01 21:40:09126
127 // Set whenever an override list file is available and the model file is
128 // loaded into memory. Reset whenever the model file is unloaded.
129 // Used on the UI thread. Lookups in this mapping should have |PreprocessHost|
130 // applied first.
Arthur Sonzognic571efb2024-01-26 20:26:18131 std::optional<std::unordered_map<std::string, std::vector<int32_t>>>
Robert Ogdenad99d6f62023-05-01 21:40:09132 override_list_;
133
134 // The version of topics model provided by the server in the model metadata
135 // which specifies the expected functionality of execution not contained
136 // within the model itself (e.g., preprocessing/post processing).
137 int version_ = 0;
138
139 // Counts the number of batches that are in progress. This counter is
140 // incremented in |StartBatchAnnotate| and decremented in |OnBatchComplete|.
141 // When this counter is 0 in |OnBatchComplete|, the model in unloaded from
142 // memory.
143 size_t in_progess_batches_ = 0;
144
Robert Ogdene4290802023-10-26 23:29:21145 // Callbacks that are run when the model is updated with the correct taxonomy
146 // version.
147 base::OnceClosureList model_available_callbacks_;
rajendrant23411d82023-08-11 19:56:17148
Robert Ogdenad99d6f62023-05-01 21:40:09149 SEQUENCE_CHECKER(sequence_checker_);
150
151 base::WeakPtrFactory<AnnotatorImpl> weak_ptr_factory_{this};
152};
153
154} // namespace browsing_topics
155
156#endif // COMPONENTS_BROWSING_TOPICS_ANNOTATOR_IMPL_H_