diff --git a/ai_gateway/api/v2/code/completions.py b/ai_gateway/api/v2/code/completions.py
index d6a0b266cba74638f6e703dd03c93efbefd00e9d..b01bf5966646742d926718b9d5c93e89c1a793d1 100644
--- a/ai_gateway/api/v2/code/completions.py
+++ b/ai_gateway/api/v2/code/completions.py
@@ -13,6 +13,7 @@ from ai_gateway.api.snowplow_context import get_snowplow_code_suggestion_context
 from ai_gateway.api.v2.code.typing import (
     CompletionsRequestV1,
     CompletionsRequestV2,
+    CompletionsRequestV3,
     GenerationsRequestV1,
     GenerationsRequestV2,
     GenerationsRequestV3,
@@ -73,7 +74,7 @@ request_log = get_request_logger("codesuggestions")
 router = APIRouter()
 
 CompletionsRequestWithVersion = Annotated[
-    Union[CompletionsRequestV1, CompletionsRequestV2],
+    Union[CompletionsRequestV1, CompletionsRequestV2, CompletionsRequestV3],
     Body(discriminator="prompt_version"),
 ]
 
@@ -146,7 +147,7 @@ async def completions(
     except Exception as e:
         log_exception(e)
 
-    request_log.debug(
+    request_log.info(
         "code completion input:",
         model_name=payload.model_name,
         model_provider=payload.model_provider,
@@ -158,11 +159,14 @@ async def completions(
     )
 
     kwargs = {}
+
     if payload.model_provider == KindModelProvider.ANTHROPIC:
-        code_completions = completions_anthropic_factory()
+        code_completions = completions_anthropic_factory(
+            model__name=payload.model_name,
+        )
 
-        # We support the prompt version 2 only with the Anthropic models
-        if payload.prompt_version == 2:
+        # We support the prompt version 3 only with the Anthropic models
+        if payload.prompt_version == 3:
             kwargs.update({"raw_prompt": payload.prompt})
     elif payload.model_provider in (
         KindModelProvider.LITELLM,
diff --git a/ai_gateway/api/v2/code/typing.py b/ai_gateway/api/v2/code/typing.py
index 14fbeb1d4108fe56acee913e3360dd8e51f7ef22..f811d9fad902e8da1193a3fc9cd3c354891b2640 100644
--- a/ai_gateway/api/v2/code/typing.py
+++ b/ai_gateway/api/v2/code/typing.py
@@ -24,6 +24,7 @@ __all__ = [
     "CompletionsRequestV1",
     "GenerationsRequestV1",
     "CompletionsRequestV2",
+    "CompletionsRequestV3",
     "GenerationsRequestV2",
     "SuggestionsResponse",
     "StreamSuggestionsResponse",
@@ -104,6 +105,11 @@ class CompletionsRequestV2(CompletionsRequest):
     prompt: Optional[str] = None
 
 
+class CompletionsRequestV3(CompletionsRequest):
+    prompt_version: Literal[3]
+    prompt: Optional[list[Message]] = None
+
+
 class GenerationsRequestV2(GenerationsRequest):
     prompt_version: Literal[2]
     prompt: str
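Reviewer note: for reference, here is a minimal sketch of a `prompt_version: 3` request body that the new `CompletionsRequestV3` schema accepts. Field values are illustrative only, borrowed from the tests later in this diff; the full set of required fields comes from the `CompletionsRequest` base model, which is not shown here.

```python
# Illustrative only: a v2 /code/completions body selecting the new schema.
# `prompt` is now a list of chat messages instead of a single string.
payload = {
    "prompt_version": 3,
    "project_path": "gitlab-org/gitlab",  # hypothetical value
    "project_id": 278964,
    "model_provider": "anthropic",
    "model_name": "claude-3-5-sonnet-20240620",
    "current_file": {
        "file_name": "main.py",
        "content_above_cursor": "def search",
        "content_below_cursor": "",
    },
    "prompt": [
        {"role": "system", "content": "You are a code completion tool that performs Fill-in-the-middle."},
        {"role": "user", "content": "// write a function to find the max"},
    ],
}
```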
diff --git a/ai_gateway/api/v3/code/completions.py b/ai_gateway/api/v3/code/completions.py
index 858b8115c5216e258e03d1b1f15936a45879fd19..6fc5c19ea57c152fb11136786e11132d5626dd45 100644
--- a/ai_gateway/api/v3/code/completions.py
+++ b/ai_gateway/api/v3/code/completions.py
@@ -116,12 +116,15 @@ async def code_completion(
     code_context: list[CodeContextPayload] = None,
     snowplow_event_context: Optional[SnowplowEventContext] = None,
 ):
+    kwargs = {}
+
     if payload.model_provider == ModelProvider.ANTHROPIC:
-        engine = completions_anthropic_factory()
+        # TODO: As we migrate to v3 we can rewrite this to use prompt registry
+        engine = completions_anthropic_factory(model__name=payload.model_name)
+        kwargs.update({"raw_prompt": payload.prompt})
     else:
         engine = completions_legacy_factory()
 
-    kwargs = {}
     if payload.choices_count > 0:
         kwargs.update({"candidate_count": payload.choices_count})
 
diff --git a/ai_gateway/api/v3/code/typing.py b/ai_gateway/api/v3/code/typing.py
index c195557901793fd44bf119693aafae656a87584d..14cf8a46af8a4e1aeb0ae6a27ef4270e8ce37592 100644
--- a/ai_gateway/api/v3/code/typing.py
+++ b/ai_gateway/api/v3/code/typing.py
@@ -6,6 +6,7 @@ from pydantic import BaseModel, ConfigDict, Field, StringConstraints
 from starlette.responses import StreamingResponse
 
 from ai_gateway.code_suggestions import ModelProvider
+from ai_gateway.models import Message
 
 __all__ = [
     "CodeEditorComponents",
@@ -46,6 +47,8 @@ class EditorContentPayload(BaseModel):
 
 class EditorContentCompletionPayload(EditorContentPayload):
     choices_count: Optional[int] = 0
+    model_name: Optional[str] = None
+    prompt: Optional[str | list[Message]] = None
 
 
 class EditorContentGenerationPayload(EditorContentPayload):
diff --git a/ai_gateway/code_suggestions/base.py b/ai_gateway/code_suggestions/base.py
index 372aa5899f0098ced6b6d1aa72a960abf23baca8..1fb351f991dd021540aeb35513dd8d26a6ad74dc 100644
--- a/ai_gateway/code_suggestions/base.py
+++ b/ai_gateway/code_suggestions/base.py
@@ -51,6 +51,7 @@ USE_CASES_MODELS_MAP = {
     KindUseCase.CODE_COMPLETIONS: {
         KindAnthropicModel.CLAUDE_INSTANT_1_1,
         KindAnthropicModel.CLAUDE_INSTANT_1_2,
+        KindAnthropicModel.CLAUDE_3_5_SONNET,
         KindVertexTextModel.CODE_GECKO_002,
         KindVertexTextModel.CODESTRAL_2405,
         KindLiteLlmModel.CODEGEMMA,
diff --git a/ai_gateway/code_suggestions/completions.py b/ai_gateway/code_suggestions/completions.py
index ff216e31a2d4e40d24d07eabb9331ed7902649e4..04469e7772d6b5cae0fdb56dd787f3c9900add90 100644
--- a/ai_gateway/code_suggestions/completions.py
+++ b/ai_gateway/code_suggestions/completions.py
@@ -24,7 +24,7 @@ from ai_gateway.instrumentators import (
     TextGenModelInstrumentator,
     benchmark,
 )
-from ai_gateway.models import ModelAPICallError, ModelAPIError
+from ai_gateway.models import ChatModelBase, Message, ModelAPICallError, ModelAPIError
 from ai_gateway.models.agent_model import AgentModel
 from ai_gateway.models.base import TokensConsumptionMetadata
 from ai_gateway.models.base_text import (
@@ -143,7 +143,7 @@ class CodeCompletions:
         self,
         prefix: str,
         suffix: str,
-        raw_prompt: Optional[str] = None,
+        raw_prompt: Optional[str | list[Message]] = None,
         code_context: Optional[list] = None,
         context_max_percent: Optional[float] = None,
     ) -> Prompt:
@@ -168,7 +168,7 @@ class CodeCompletions:
         suffix: str,
         file_name: str,
         editor_lang: Optional[str] = None,
-        raw_prompt: Optional[str] = None,
+        raw_prompt: Optional[str | list[Message]] = None,
         code_context: Optional[list] = None,
         stream: bool = False,
         snowplow_event_context: Optional[SnowplowEventContext] = None,
@@ -196,6 +196,10 @@ class CodeCompletions:
                 params = {"prefix": prompt.prefix, "suffix": prompt.suffix}
                 res = await self.model.generate(params, stream)
+            elif isinstance(self.model, ChatModelBase):
+                res = await self.model.generate(
+                    prompt.prefix, stream=stream, **kwargs
+                )
             else:
                 res = await self.model.generate(
                     prompt.prefix, prompt.suffix, stream, **kwargs
                 )
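Reviewer note: the behavioral core of this change is the new `ChatModelBase` branch in `CodeCompletions.execute`. A condensed sketch of the resulting three-way dispatch follows; one assumption (inferred from the hunk, not shown in this diff) is that when a `list[Message]` `raw_prompt` is supplied, `_build_prompt` carries it through as `prompt.prefix`, which is why the chat branch passes `prompt.prefix` as the message list.

```python
# Condensed from the hunk above; instrumentation and error handling elided.
if isinstance(self.model, AgentModel):
    # Agent models take a prefix/suffix params dict.
    params = {"prefix": prompt.prefix, "suffix": prompt.suffix}
    res = await self.model.generate(params, stream)
elif isinstance(self.model, ChatModelBase):
    # Chat models (the new Anthropic path) receive the message list.
    res = await self.model.generate(prompt.prefix, stream=stream, **kwargs)
else:
    # Plain text-gen models keep the prefix/suffix calling convention.
    res = await self.model.generate(prompt.prefix, prompt.suffix, stream, **kwargs)
```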
diff --git a/ai_gateway/code_suggestions/container.py b/ai_gateway/code_suggestions/container.py
index 4d52a57fdeafbbe5ff2675e306b66b38d0f5b263..d66bb54d6a2e3d319322dc8f37131ee4a0942e98 100644
--- a/ai_gateway/code_suggestions/container.py
+++ b/ai_gateway/code_suggestions/container.py
@@ -101,6 +101,7 @@ class ContainerCodeCompletions(containers.DeclarativeContainer):
     tokenizer = providers.Dependency(instance_of=PreTrainedTokenizerFast)
     vertex_code_gecko = providers.Dependency(instance_of=TextGenModelBase)
     anthropic_claude = providers.Dependency(instance_of=TextGenModelBase)
+    anthropic_claude_chat = providers.Dependency(instance_of=ChatModelBase)
     litellm = providers.Dependency(instance_of=TextGenModelBase)
     agent_model = providers.Dependency(instance_of=TextGenModelBase)
     snowplow_instrumentator = providers.Dependency(instance_of=SnowplowInstrumentator)
@@ -128,12 +129,7 @@ class ContainerCodeCompletions(containers.DeclarativeContainer):
 
     anthropic = providers.Factory(
         CodeCompletions,
-        model=providers.Factory(
-            anthropic_claude,
-            name=KindAnthropicModel.CLAUDE_INSTANT_1_2,
-            stop_sequences=["", anthropic.HUMAN_PROMPT],
-            max_tokens_to_sample=128,
-        ),
+        model=providers.Factory(anthropic_claude_chat),
         tokenization_strategy=providers.Factory(
             TokenizerTokenStrategy, tokenizer=tokenizer
         ),
@@ -201,6 +197,7 @@ class ContainerCodeSuggestions(containers.DeclarativeContainer):
         tokenizer=tokenizer,
         vertex_code_gecko=models.vertex_code_gecko,
         anthropic_claude=models.anthropic_claude,
+        anthropic_claude_chat=models.anthropic_claude_chat,
         litellm=models.litellm,
         agent_model=models.agent_model,
         config=config,
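Reviewer note: the `model__name` keyword seen throughout this MR is dependency-injector's double-underscore override syntax — a keyword passed to the outer factory as `model__name=...` overrides the `name` argument of the nested `model` sub-factory. A self-contained sketch with stand-in classes (the `Fake*` names are hypothetical, not part of this codebase):

```python
from dependency_injector import providers


class FakeChatModel:
    def __init__(self, name: str = "claude-instant-1.2"):
        self.name = name


class FakeCodeCompletions:
    def __init__(self, model: FakeChatModel):
        self.model = model


anthropic = providers.Factory(
    FakeCodeCompletions,
    model=providers.Factory(FakeChatModel),
)

# Overrides the inner factory's `name` kwarg via the `model__` prefix.
engine = anthropic(model__name="claude-3-5-sonnet-20240620")
assert engine.model.name == "claude-3-5-sonnet-20240620"
```

This is what lets `completions_anthropic_factory(model__name=payload.model_name)` select the Claude model per request, instead of the container hard-coding `claude-instant-1.2` as before.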
", + }, + { + "role": "user", + "content": "\n// write a function to find the max\n\n\n\n\treturn min\n}\n')]", + }, + ] + data.update({"prompt": raw_prompt}) + raw_prompt = [ + Message(role=prompt["role"], content=prompt["content"]) + for prompt in raw_prompt + ] + code_completions_kwargs.update({"raw_prompt": raw_prompt}) response = mock_client.post( "/completions", @@ -457,6 +470,7 @@ class TestCodeCompletions: "project_id": 278964, "model_provider": "anthropic", "current_file": current_file, + "model_name": "claude-3-5-sonnet-20240620", } code_completions_kwargs = {} @@ -795,6 +809,7 @@ class TestCodeCompletions: "current_file": current_file, "stream": True, "context": context, + "model_name": "claude-3-5-sonnet-20240620", } if model: data["model_name"] = model @@ -994,6 +1009,7 @@ class TestCodeCompletions: "project_id": 278964, "model_provider": "anthropic", "current_file": current_file, + "model_name": "claude-3-5-sonnet-20240620", } response = mock_client.post( diff --git a/tests/api/v3/test_v3_code.py b/tests/api/v3/test_v3_code.py index ceef3c24e4e5bbe3f051c9ac9eebd6a59d7029cb..cc5b5a357545ea45407027b3c39a88afd0865f79 100644 --- a/tests/api/v3/test_v3_code.py +++ b/tests/api/v3/test_v3_code.py @@ -173,7 +173,7 @@ class TestEditorContentCompletion: }, { "engine": "anthropic", - "name": "claude-instant-1.2", + "name": "claude-3-5-sonnet-20240620", "lang": "python", }, ), @@ -208,7 +208,7 @@ class TestEditorContentCompletion: def test_model_provider( self, mock_client: TestClient, - mock_anthropic: Mock, + mock_anthropic_chat: Mock, mock_code_gecko: Mock, model_provider: str, expected_code: int, @@ -221,6 +221,17 @@ class TestEditorContentCompletion: "content_below_cursor": "\n", "language_identifier": "python", "model_provider": model_provider or None, + "model_name": "claude-3-5-sonnet-20240620", + "prompt": [ + { + "role": "system", + "content": "You are a code completion tool that performs Fill-in-the-middle", + }, + { + "role": "user", + "content": "\n// a function to find the max\n for \n\n\n\n\treturn min\n}\n", + }, + ], } prompt_component = { @@ -255,7 +266,7 @@ class TestEditorContentCompletion: assert body["metadata"]["model"] == expected_model - mock = mock_anthropic if model_provider == "anthropic" else mock_code_gecko + mock = mock_anthropic_chat if model_provider == "anthropic" else mock_code_gecko mock.assert_called @@ -272,6 +283,7 @@ class TestEditorContentCompletion: "language_identifier": "python", "model_provider": "anthropic", "stream": True, + "model_name": "claude-3-5-sonnet-20240620", } prompt_component = { @@ -319,6 +331,7 @@ class TestEditorContentCompletion: stream=True, code_context=None, snowplow_event_context=expected_snowplow_event, + raw_prompt=None, ) @@ -857,6 +870,7 @@ class TestIncomingRequest: "content_below_cursor": "", # FIXME: Forcing anthropic as vertex-ai is not working "model_provider": "anthropic", + "model_name": "claude-3-5-sonnet-20240620", }, }, ], @@ -873,6 +887,7 @@ class TestIncomingRequest: "file_name": "test", "content_above_cursor": "def hello_world():", "content_below_cursor": "", + "model_name": "claude-3-5-sonnet-20240620", }, }, ], @@ -889,6 +904,7 @@ class TestIncomingRequest: "file_name": "test", "content_above_cursor": "def hello_world():", "content_below_cursor": "", + "model_name": "claude-3-5-sonnet-20240620", }, }, ] @@ -905,6 +921,7 @@ class TestIncomingRequest: "payload": { "content_above_cursor": "def hello_world():", "content_below_cursor": "", + "model_name": "claude-3-5-sonnet-20240620", }, }, ], diff --git 
diff --git a/tests/code_suggestions/test_container.py b/tests/code_suggestions/test_container.py
index b7ee620110637b5f184ed44d67a626f9b4dfee12..2cb7d7ba74f45f8316b66999006398650e697b5c 100644
--- a/tests/code_suggestions/test_container.py
+++ b/tests/code_suggestions/test_container.py
@@ -15,7 +15,10 @@ def test_container(mock_container: containers.DeclarativeContainer):
     generations = code_suggestions.generations
 
     assert isinstance(completions.vertex_legacy(), CodeCompletionsLegacy)
-    assert isinstance(completions.anthropic(), CodeCompletions)
+    assert isinstance(
+        completions.anthropic(model__name=KindAnthropicModel.CLAUDE_3_5_SONNET),
+        CodeCompletions,
+    )
     assert isinstance(
         completions.litellm_factory(model__name=KindLiteLlmModel.CODEGEMMA),
         CodeCompletions,
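One last note on the `model__name` value used here: `KindAnthropicModel.CLAUDE_3_5_SONNET` is the enum member added to `USE_CASES_MODELS_MAP` in `base.py`. Its definition is outside this diff, so the following equivalence is an assumption, inferred from the `model_name` strings used throughout the tests:

```python
from ai_gateway.models import KindAnthropicModel

# Assumed mapping, inferred from the model_name values in this diff's tests.
assert KindAnthropicModel.CLAUDE_3_5_SONNET.value == "claude-3-5-sonnet-20240620"
```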