blob: 3aa968a8c918d45062d547440b5d068eeba29c5d [file] [log] [blame]
Nayeem Jahan Rafi42c80992024-10-06 19:33:201// Copyright 2024 The Chromium Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "components/manta/walrus_provider.h"
6
7#include <memory>
8#include <string>
9
10#include "base/strings/stringprintf.h"
11#include "base/test/bind.h"
12#include "base/test/metrics/histogram_tester.h"
13#include "base/test/task_environment.h"
14#include "base/time/time.h"
15#include "components/manta/base_provider.h"
16#include "components/manta/base_provider_test_helper.h"
17#include "components/manta/manta_status.h"
18#include "components/manta/proto/manta.pb.h"
19#include "components/signin/public/base/consent_level.h"
20#include "components/signin/public/identity_manager/identity_manager.h"
21#include "components/signin/public/identity_manager/identity_test_environment.h"
22#include "net/base/net_errors.h"
23#include "net/http/http_status_code.h"
24#include "net/http/http_util.h"
25#include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
26#include "services/network/public/cpp/weak_wrapper_shared_url_loader_factory.h"
27#include "services/network/test/test_url_loader_factory.h"
28#include "testing/gtest/include/gtest/gtest.h"
Nayeem Jahan Rafi39c58cb2024-11-06 06:13:1229#include "ui/gfx/codec/jpeg_codec.h"
30#include "ui/gfx/image/image_skia_operations.h"
Nayeem Jahan Rafi42c80992024-10-06 19:33:2031
32namespace manta {
33
34namespace {
35constexpr char kMockEndpoint[] = "https://my-endpoint.com";
Nayeem Jahan Rafi39c58cb2024-11-06 06:13:1236
37std::vector<uint8_t> CreateJPGBytes(int width, int height) {
38 SkBitmap bitmap;
39 bitmap.allocN32Pixels(width, height);
40 bitmap.eraseColor(SK_ColorRED); // Fill with a solid color
41 auto image_bytes = gfx::JPEGCodec::Encode(bitmap, 100);
42 return image_bytes.value();
43}
Nayeem Jahan Rafi42c80992024-10-06 19:33:2044}
45
46class FakeWalrusProvider : public WalrusProvider, public FakeBaseProvider {
47 public:
48 FakeWalrusProvider(
49 scoped_refptr<network::SharedURLLoaderFactory> test_url_loader_factory,
50 signin::IdentityManager* identity_manager)
51 : BaseProvider(test_url_loader_factory, identity_manager),
52 WalrusProvider(test_url_loader_factory,
53 identity_manager,
54 ProviderParams()),
55 FakeBaseProvider(test_url_loader_factory, identity_manager) {}
Nayeem Jahan Rafi39c58cb2024-11-06 06:13:1256 std::optional<std::vector<uint8_t>> DownscaleImageIfNeeded(
57 const std::vector<uint8_t>& image_bytes,
58 int32_t max_pixels_after_resizing) {
59 return WalrusProvider::DownscaleImageIfNeeded(image_bytes,
60 max_pixels_after_resizing);
61 }
Nayeem Jahan Rafi42c80992024-10-06 19:33:2062};
63
64class WalrusProviderTest : public BaseProviderTest {
65 public:
66 WalrusProviderTest() = default;
67
68 WalrusProviderTest(const WalrusProviderTest&) = delete;
69 WalrusProviderTest& operator=(const WalrusProviderTest&) = delete;
70
71 ~WalrusProviderTest() override = default;
72
73 std::unique_ptr<FakeWalrusProvider> CreateWalrusProvider() {
74 return std::make_unique<FakeWalrusProvider>(
75 base::MakeRefCounted<network::WeakWrapperSharedURLLoaderFactory>(
76 &test_url_loader_factory_),
77 identity_test_env_->identity_manager());
78 }
79};
80
81// Test that responses with http_status_code != net::HTTP_OK are captured.
82TEST_F(WalrusProviderTest, CaptureUnexcpetedStatusCode) {
83 std::unique_ptr<FakeWalrusProvider> walrus_provider = CreateWalrusProvider();
84
85 SetEndpointMockResponse(GURL{kMockEndpoint}, /*response_data=*/"",
86 net::HTTP_BAD_REQUEST, net::OK);
87 std::optional<std::string> text_prompt = "text pompt";
88 std::vector<std::vector<uint8_t>> images;
89
90 walrus_provider->Filter(
91 text_prompt, images,
92 base::BindLambdaForTesting(
93 [quit_closure = task_environment_.QuitClosure()](
94 base::Value::Dict response, MantaStatus manta_status) {
95 EXPECT_EQ(manta_status.status_code,
96 MantaStatusCode::kBackendFailure);
97 quit_closure.Run();
98 }));
99 task_environment_.RunUntilQuit();
100}
101
102// Test that responses with network errors are captured.
103TEST_F(WalrusProviderTest, CaptureNetError) {
104 std::unique_ptr<FakeWalrusProvider> walrus_provider = CreateWalrusProvider();
105
106 SetEndpointMockResponse(GURL{kMockEndpoint}, /*response_data=*/"",
107 net::HTTP_OK, net::ERR_FAILED);
108 std::optional<std::string> text_prompt = "text pompt";
109 std::vector<std::vector<uint8_t>> images;
110
111 walrus_provider->Filter(
112 text_prompt, images,
113 base::BindLambdaForTesting(
114 [quit_closure = task_environment_.QuitClosure()](
115 base::Value::Dict response, MantaStatus manta_status) {
116 EXPECT_EQ(manta_status.status_code,
117 MantaStatusCode::kNoInternetConnection);
118 quit_closure.Run();
119 }));
120 task_environment_.RunUntilQuit();
121}
122
123// Test Manta Provider rejects invalid input data. Currently we require the
124// input must contain a valid text prompt or an image.
125TEST_F(WalrusProviderTest, InvalidInput) {
126 std::unique_ptr<FakeWalrusProvider> walrus_provider = CreateWalrusProvider();
127
128 SetEndpointMockResponse(GURL{kMockEndpoint}, /*response_data=*/"",
129 net::HTTP_OK, net::OK);
130 std::optional<std::string> text_prompt;
131 std::vector<std::vector<uint8_t>> images;
132
133 walrus_provider->Filter(
134 text_prompt, images,
135 base::BindLambdaForTesting(
136 [quit_closure = task_environment_.QuitClosure()](
137 base::Value::Dict response, MantaStatus manta_status) {
138 EXPECT_EQ(manta_status.status_code, MantaStatusCode::kInvalidInput);
139 quit_closure.Run();
140 }));
141 task_environment_.RunUntilQuit();
142}
143
144// Test the response when the text prompt / image is safe.
145TEST_F(WalrusProviderTest, SuccessfulResponse) {
146 std::string image_bytes = "image_bytes";
147 base::HistogramTester histogram_tester;
148 manta::proto::Response response;
149 auto* output_data = response.add_output_data();
150 output_data->set_text("text pompt");
151 output_data = response.add_output_data();
152 output_data->mutable_image()->set_serialized_bytes(image_bytes);
153
154 std::string response_data;
155 response.SerializeToString(&response_data);
156
157 SetEndpointMockResponse(GURL{kMockEndpoint}, response_data, net::HTTP_OK,
158 net::OK);
159 std::unique_ptr<FakeWalrusProvider> walrus_provider = CreateWalrusProvider();
160 auto quit_closure = task_environment_.QuitClosure();
161 std::optional<std::string> text_prompt = "text pompt";
162 std::vector<std::vector<uint8_t>> images = {
163 std::vector<uint8_t>(image_bytes.begin(), image_bytes.end())};
164
165 walrus_provider->Filter(
166 text_prompt, images,
167 base::BindLambdaForTesting([&quit_closure](base::Value::Dict response,
168 MantaStatus manta_status) {
169 // Even though the response has text and image, walrus just
170 // returns the status code
171 ASSERT_EQ(MantaStatusCode::kOk, manta_status.status_code);
172 ASSERT_TRUE(response.empty());
173 quit_closure.Run();
174 }));
175 task_environment_.RunUntilQuit();
176
177 // Metric is logged when response is successfully parsed.
178 histogram_tester.ExpectTotalCount("Ash.MantaService.WalrusProvider.TimeCost",
179 1);
180}
181
182// Test the response when the text prompt is blocked.
183TEST_F(WalrusProviderTest, TextBlocked) {
184 std::string image_bytes = "image_bytes";
185 base::HistogramTester histogram_tester;
186 manta::proto::Response response;
187 manta::proto::FilteredData& filtered_data = *response.add_filtered_data();
188 filtered_data.set_reason(manta::proto::FilteredReason::TEXT_SAFETY);
189 std::string response_data;
190 response.SerializeToString(&response_data);
191
192 SetEndpointMockResponse(GURL{kMockEndpoint}, response_data, net::HTTP_OK,
193 net::OK);
194 std::unique_ptr<FakeWalrusProvider> walrus_provider = CreateWalrusProvider();
195 auto quit_closure = task_environment_.QuitClosure();
196 std::optional<std::string> text_prompt = "text pompt";
197 std::vector<std::vector<uint8_t>> images = {
198 std::vector<uint8_t>(image_bytes.begin(), image_bytes.end())};
199
200 walrus_provider->Filter(
201 text_prompt, images,
202 base::BindLambdaForTesting([&quit_closure](base::Value::Dict response,
203 MantaStatus manta_status) {
204 // Even though the response has text and image, walrus just
205 // returns the status code.
206 ASSERT_EQ(MantaStatusCode::kBlockedOutputs, manta_status.status_code);
207 ASSERT_EQ(response.size(), 1u);
208 ASSERT_TRUE(response.FindBool("text_blocked"));
209 quit_closure.Run();
210 }));
211 task_environment_.RunUntilQuit();
212
213 // Metric is logged when response is successfully parsed.
214 histogram_tester.ExpectTotalCount("Ash.MantaService.WalrusProvider.TimeCost",
215 1);
216}
217
218// Test the response when the text prompt and images is blocked.
219TEST_F(WalrusProviderTest, TextImageBothBlocked) {
220 std::string image_bytes = "image_bytes";
221 base::HistogramTester histogram_tester;
222 manta::proto::Response response;
223 auto* filtered_data = response.add_filtered_data();
224 filtered_data->set_reason(manta::proto::FilteredReason::TEXT_SAFETY);
225 filtered_data = response.add_filtered_data();
226 filtered_data->set_reason(manta::proto::FilteredReason::IMAGE_SAFETY);
227 filtered_data = response.add_filtered_data();
228 filtered_data->set_reason(manta::proto::FilteredReason::IMAGE_SAFETY);
229 std::string response_data;
230 response.SerializeToString(&response_data);
231
232 SetEndpointMockResponse(GURL{kMockEndpoint}, response_data, net::HTTP_OK,
233 net::OK);
234 std::unique_ptr<FakeWalrusProvider> walrus_provider = CreateWalrusProvider();
235 auto quit_closure = task_environment_.QuitClosure();
236 std::optional<std::string> text_prompt = "text pompt";
237 std::vector<std::vector<uint8_t>> images = {
238 std::vector<uint8_t>(image_bytes.begin(), image_bytes.end()),
239 std::vector<uint8_t>(image_bytes.begin(), image_bytes.end())};
240
241 walrus_provider->Filter(
242 text_prompt, images,
243 base::BindLambdaForTesting([&quit_closure](base::Value::Dict response,
244 MantaStatus manta_status) {
245 // Even though the response has text and image, walrus just
246 // returns the status code
247 ASSERT_EQ(MantaStatusCode::kBlockedOutputs, manta_status.status_code);
248 ASSERT_EQ(response.size(), 2u);
249 ASSERT_TRUE(response.FindBool("text_blocked"));
250 ASSERT_TRUE(response.FindBool("image_blocked"));
251 quit_closure.Run();
252 }));
253 task_environment_.RunUntilQuit();
254
255 // Metric is logged when response is successfully parsed.
256 histogram_tester.ExpectTotalCount("Ash.MantaService.WalrusProvider.TimeCost",
257 1);
258}
259
260TEST_F(WalrusProviderTest, EmptyResponseAfterIdentityManagerShutdown) {
261 base::HistogramTester histogram_tester;
Nayeem Jahan Rafi39c58cb2024-11-06 06:13:12262 std::unique_ptr<FakeWalrusProvider> walrus_provider = CreateWalrusProvider();
Nayeem Jahan Rafi42c80992024-10-06 19:33:20263
264 identity_test_env_.reset();
265
266 std::string text_prompt = "text pompt";
Nayeem Jahan Rafi39c58cb2024-11-06 06:13:12267 walrus_provider->Filter(
Nayeem Jahan Rafi42c80992024-10-06 19:33:20268 text_prompt, base::BindLambdaForTesting(
269 [quit_closure = task_environment_.QuitClosure()](
270 base::Value::Dict dict, MantaStatus manta_status) {
271 ASSERT_TRUE(dict.empty());
272 ASSERT_EQ(MantaStatusCode::kNoIdentityManager,
273 manta_status.status_code);
274 quit_closure.Run();
275 }));
276 task_environment_.RunUntilQuit();
277
278 // No metric logged.
279 histogram_tester.ExpectTotalCount("Ash.MantaService.WalrusProvider.TimeCost",
280 0);
281}
282
Nayeem Jahan Rafi39c58cb2024-11-06 06:13:12283TEST_F(WalrusProviderTest, InvalidOrUnknownImageFormatIsNotDownscaled) {
284 std::unique_ptr<FakeWalrusProvider> walrus_provider = CreateWalrusProvider();
285 std::vector<uint8_t> invalid_image_bytes = {1, 2, 3, 4, 5};
286
287 std::optional<std::vector<uint8_t>> resized_image =
288 walrus_provider->DownscaleImageIfNeeded(invalid_image_bytes, 100);
289
290 ASSERT_FALSE(resized_image.has_value());
291}
292
293TEST_F(WalrusProviderTest, LargerImageIsDownscaled) {
294 std::unique_ptr<FakeWalrusProvider> walrus_provider = CreateWalrusProvider();
295 std::vector<uint8_t> image_bytes = CreateJPGBytes(20, 40);
296
297 std::optional<std::vector<uint8_t>> resized_image_bytes =
298 walrus_provider->DownscaleImageIfNeeded(image_bytes, 10 * 10);
299
300 ASSERT_TRUE(resized_image_bytes.has_value());
301 auto resized_image = gfx::JPEGCodec::Decode(resized_image_bytes.value());
302
303 ASSERT_TRUE(resized_image.height() > 0 && resized_image.width() > 0);
304 ASSERT_EQ(resized_image.width(), 7);
305 ASSERT_EQ(resized_image.height(), 14);
306}
307
308TEST_F(WalrusProviderTest, SmallerImageIsNotDownscaled) {
309 std::unique_ptr<FakeWalrusProvider> walrus_provider = CreateWalrusProvider();
310 std::vector<uint8_t> image_bytes = CreateJPGBytes(5, 12);
311
312 std::optional<std::vector<uint8_t>> resized_image_bytes =
313 walrus_provider->DownscaleImageIfNeeded(image_bytes, 10 * 10);
314 ASSERT_TRUE(resized_image_bytes.has_value());
315 auto resized_image = gfx::JPEGCodec::Decode(resized_image_bytes.value());
316
317 ASSERT_TRUE(resized_image.height() > 0 && resized_image.width() > 0);
318 ASSERT_EQ(resized_image.width(), 5);
319 ASSERT_EQ(resized_image.height(), 12);
320}
321
Nayeem Jahan Rafi0ac1ec42024-12-17 13:56:45322// Test the response when the generated region is provided.
323TEST_F(WalrusProviderTest, GeneratedRegion) {
324 std::string generated_region_image_bytes = "generated_region_image_bytes";
325 manta::proto::Response response;
326 auto* output_data = response.add_output_data();
327 output_data->set_text("text prompt");
328
329 output_data = response.add_output_data();
330 // Aratea should return the tag instead of image
331 output_data->set_text("generated_region");
332
333 std::string response_data;
334 response.SerializeToString(&response_data);
335
336 SetEndpointMockResponse(GURL{kMockEndpoint}, response_data, net::HTTP_OK,
337 net::OK);
338 std::unique_ptr<FakeWalrusProvider> walrus_provider = CreateWalrusProvider();
339 auto quit_closure = task_environment_.QuitClosure();
340 std::optional<std::string> text_prompt = "text pompt";
341 std::vector<std::vector<uint8_t>> images = {
342 std::vector<uint8_t>(generated_region_image_bytes.begin(),
343 generated_region_image_bytes.end())};
344 std::vector<manta::WalrusProvider::ImageType> image_types = {
345 manta::WalrusProvider::ImageType::kGeneratedRegion};
346
347 walrus_provider->Filter(
348 text_prompt, images, image_types,
349 base::BindLambdaForTesting([&quit_closure](base::Value::Dict response,
350 MantaStatus manta_status) {
351 // Even though the response has text and image, walrus just
352 // returns the status code
353 ASSERT_EQ(MantaStatusCode::kOk, manta_status.status_code);
354 ASSERT_TRUE(response.empty());
355 quit_closure.Run();
356 }));
357 task_environment_.RunUntilQuit();
358}
359
360// Test the response when the image type argument size mismatch with number of
361// images.
362TEST_F(WalrusProviderTest, ImageTypeSizeMismatch) {
363 std::string generated_region_image_bytes = "generated_region_image_bytes";
364 manta::proto::Response response;
365 auto* output_data = response.add_output_data();
366 output_data->set_text("text prompt");
367
368 output_data = response.add_output_data();
369 // Aratea should return the tag instead of image
370 output_data->set_text("output_image");
371
372 std::string response_data;
373 response.SerializeToString(&response_data);
374
375 SetEndpointMockResponse(GURL{kMockEndpoint}, response_data, net::HTTP_OK,
376 net::OK);
377 std::unique_ptr<FakeWalrusProvider> walrus_provider = CreateWalrusProvider();
378 auto quit_closure = task_environment_.QuitClosure();
379 std::optional<std::string> text_prompt = "text pompt";
380 std::vector<std::vector<uint8_t>> images = {
381 std::vector<uint8_t>(generated_region_image_bytes.begin(),
382 generated_region_image_bytes.end())};
383 std::vector<manta::WalrusProvider::ImageType> image_types = {
384 manta::WalrusProvider::ImageType::kOutputImage,
385 manta::WalrusProvider::ImageType::kGeneratedRegion};
386
387 walrus_provider->Filter(
388 text_prompt, images, image_types,
389 base::BindLambdaForTesting([&quit_closure](base::Value::Dict response,
390 MantaStatus manta_status) {
391 EXPECT_EQ(manta_status.status_code, MantaStatusCode::kInvalidInput);
392 quit_closure.Run();
393 }));
394 task_environment_.RunUntilQuit();
395}
396
Nayeem Jahan Rafi42c80992024-10-06 19:33:20397} // namespace manta