diff --git a/ai_gateway/api/v2/code/completions.py b/ai_gateway/api/v2/code/completions.py
index d6a0b266cba74638f6e703dd03c93efbefd00e9d..b01bf5966646742d926718b9d5c93e89c1a793d1 100644
--- a/ai_gateway/api/v2/code/completions.py
+++ b/ai_gateway/api/v2/code/completions.py
@@ -13,6 +13,7 @@ from ai_gateway.api.snowplow_context import get_snowplow_code_suggestion_context
 from ai_gateway.api.v2.code.typing import (
     CompletionsRequestV1,
     CompletionsRequestV2,
+    CompletionsRequestV3,
     GenerationsRequestV1,
     GenerationsRequestV2,
     GenerationsRequestV3,
@@ -73,7 +74,7 @@ request_log = get_request_logger("codesuggestions")
 
 router = APIRouter()
 
 CompletionsRequestWithVersion = Annotated[
-    Union[CompletionsRequestV1, CompletionsRequestV2],
+    Union[CompletionsRequestV1, CompletionsRequestV2, CompletionsRequestV3],
     Body(discriminator="prompt_version"),
 ]
@@ -146,7 +147,7 @@ async def completions(
     except Exception as e:
         log_exception(e)
 
-    request_log.debug(
+    request_log.info(
         "code completion input:",
         model_name=payload.model_name,
         model_provider=payload.model_provider,
@@ -158,11 +159,14 @@ async def completions(
     )
 
     kwargs = {}
+
     if payload.model_provider == KindModelProvider.ANTHROPIC:
-        code_completions = completions_anthropic_factory()
-        # We support the prompt version 2 only with the Anthropic models
-        if payload.prompt_version == 2:
+        code_completions = completions_anthropic_factory(
+            model__name=payload.model_name,
+        )
+        # We only support prompt version 3 with Anthropic models
+        if payload.prompt_version == 3:
             kwargs.update({"raw_prompt": payload.prompt})
     elif payload.model_provider in (
         KindModelProvider.LITELLM,
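
For context, here is roughly what a request exercising the new path would look like. This sketch is illustrative only: the field values are invented, and only prompt_version, model_provider, model_name, and prompt matter to the routing added above.

```python
# Hypothetical v2 /code/completions request body (values invented) that the
# discriminated union added above would parse as CompletionsRequestV3. With
# prompt_version 3 the client ships the whole chat prompt itself as a list of
# role/content messages, which the endpoint forwards as raw_prompt.
payload = {
    "prompt_version": 3,
    "model_provider": "anthropic",
    "model_name": "claude-3-5-sonnet-20240620",
    "current_file": {
        "file_name": "main.py",
        "content_above_cursor": "// find the max\ndef max(a, b):",
        "content_below_cursor": "",
    },
    "prompt": [
        {"role": "system", "content": "You are a code completion tool."},
        {"role": "user", "content": "// find the max\ndef max(a, b):"},
    ],
}
```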
diff --git a/ai_gateway/api/v2/code/typing.py b/ai_gateway/api/v2/code/typing.py
index 14fbeb1d4108fe56acee913e3360dd8e51f7ef22..f811d9fad902e8da1193a3fc9cd3c354891b2640 100644
--- a/ai_gateway/api/v2/code/typing.py
+++ b/ai_gateway/api/v2/code/typing.py
@@ -24,6 +24,7 @@ __all__ = [
     "CompletionsRequestV1",
     "GenerationsRequestV1",
     "CompletionsRequestV2",
+    "CompletionsRequestV3",
     "GenerationsRequestV2",
     "SuggestionsResponse",
     "StreamSuggestionsResponse",
@@ -104,6 +105,11 @@ class CompletionsRequestV2(CompletionsRequest):
     prompt: Optional[str] = None
 
 
+class CompletionsRequestV3(CompletionsRequest):
+    prompt_version: Literal[3]
+    prompt: Optional[list[Message]] = None
+
+
 class GenerationsRequestV2(GenerationsRequest):
     prompt_version: Literal[2]
     prompt: str
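
A minimal sketch of how the Body(discriminator="prompt_version") union resolves to the new model. The classes below are simplified stand-ins, not the gateway's actual imports: pydantic reads the Literal discriminator and validates against only the matching variant.

```python
from typing import Annotated, Literal, Optional, Union

from pydantic import BaseModel, Field, TypeAdapter


class Message(BaseModel):
    role: str
    content: str


class CompletionsRequestV2(BaseModel):
    prompt_version: Literal[2]
    prompt: Optional[str] = None


class CompletionsRequestV3(BaseModel):
    prompt_version: Literal[3]
    prompt: Optional[list[Message]] = None


CompletionsRequestWithVersion = Annotated[
    Union[CompletionsRequestV2, CompletionsRequestV3],
    Field(discriminator="prompt_version"),
]

# prompt_version selects the variant; no other field is inspected.
request = TypeAdapter(CompletionsRequestWithVersion).validate_python(
    {"prompt_version": 3, "prompt": [{"role": "user", "content": "def max("}]}
)
assert isinstance(request, CompletionsRequestV3)
```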
diff --git a/ai_gateway/api/v3/code/completions.py b/ai_gateway/api/v3/code/completions.py
index 858b8115c5216e258e03d1b1f15936a45879fd19..6fc5c19ea57c152fb11136786e11132d5626dd45 100644
--- a/ai_gateway/api/v3/code/completions.py
+++ b/ai_gateway/api/v3/code/completions.py
@@ -116,12 +116,15 @@ async def code_completion(
     code_context: list[CodeContextPayload] = None,
     snowplow_event_context: Optional[SnowplowEventContext] = None,
 ):
+    kwargs = {}
+
     if payload.model_provider == ModelProvider.ANTHROPIC:
-        engine = completions_anthropic_factory()
+        # TODO: Rewrite this to use the prompt registry as we migrate to v3.
+        engine = completions_anthropic_factory(model__name=payload.model_name)
+        kwargs.update({"raw_prompt": payload.prompt})
     else:
         engine = completions_legacy_factory()
 
-    kwargs = {}
     if payload.choices_count > 0:
         kwargs.update({"candidate_count": payload.choices_count})
diff --git a/ai_gateway/api/v3/code/typing.py b/ai_gateway/api/v3/code/typing.py
index c195557901793fd44bf119693aafae656a87584d..14cf8a46af8a4e1aeb0ae6a27ef4270e8ce37592 100644
--- a/ai_gateway/api/v3/code/typing.py
+++ b/ai_gateway/api/v3/code/typing.py
@@ -6,6 +6,7 @@ from pydantic import BaseModel, ConfigDict, Field, StringConstraints
 from starlette.responses import StreamingResponse
 
 from ai_gateway.code_suggestions import ModelProvider
+from ai_gateway.models import Message
 
 __all__ = [
     "CodeEditorComponents",
@@ -46,6 +47,8 @@ class EditorContentPayload(BaseModel):
 
 class EditorContentCompletionPayload(EditorContentPayload):
     choices_count: Optional[int] = 0
+    model_name: Optional[str] = None
+    prompt: Optional[str | list[Message]] = None
 
 
 class EditorContentGenerationPayload(EditorContentPayload):
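
The widened prompt field keeps both client generations working. A quick sketch (simplified stand-in types, invented values) of what Optional[str | list[Message]] accepts:

```python
from typing import Optional

from pydantic import BaseModel


class Message(BaseModel):
    role: str
    content: str


class CompletionPayload(BaseModel):
    prompt: Optional[str | list[Message]] = None


CompletionPayload(prompt="def max(a, b):")  # legacy string prompt still validates
CompletionPayload(prompt=[{"role": "user", "content": "def max("}])  # new chat form
CompletionPayload()  # omitting the prompt entirely is also fine
```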
diff --git a/ai_gateway/code_suggestions/base.py b/ai_gateway/code_suggestions/base.py
index 372aa5899f0098ced6b6d1aa72a960abf23baca8..1fb351f991dd021540aeb35513dd8d26a6ad74dc 100644
--- a/ai_gateway/code_suggestions/base.py
+++ b/ai_gateway/code_suggestions/base.py
@@ -51,6 +51,7 @@ USE_CASES_MODELS_MAP = {
     KindUseCase.CODE_COMPLETIONS: {
         KindAnthropicModel.CLAUDE_INSTANT_1_1,
         KindAnthropicModel.CLAUDE_INSTANT_1_2,
+        KindAnthropicModel.CLAUDE_3_5_SONNET,
         KindVertexTextModel.CODE_GECKO_002,
         KindVertexTextModel.CODESTRAL_2405,
         KindLiteLlmModel.CODEGEMMA,
diff --git a/ai_gateway/code_suggestions/completions.py b/ai_gateway/code_suggestions/completions.py
index ff216e31a2d4e40d24d07eabb9331ed7902649e4..04469e7772d6b5cae0fdb56dd787f3c9900add90 100644
--- a/ai_gateway/code_suggestions/completions.py
+++ b/ai_gateway/code_suggestions/completions.py
@@ -24,7 +24,7 @@ from ai_gateway.instrumentators import (
     TextGenModelInstrumentator,
     benchmark,
 )
-from ai_gateway.models import ModelAPICallError, ModelAPIError
+from ai_gateway.models import ChatModelBase, Message, ModelAPICallError, ModelAPIError
 from ai_gateway.models.agent_model import AgentModel
 from ai_gateway.models.base import TokensConsumptionMetadata
 from ai_gateway.models.base_text import (
@@ -143,7 +143,7 @@ class CodeCompletions:
         self,
         prefix: str,
         suffix: str,
-        raw_prompt: Optional[str] = None,
+        raw_prompt: Optional[str | list[Message]] = None,
         code_context: Optional[list] = None,
         context_max_percent: Optional[float] = None,
     ) -> Prompt:
@@ -168,7 +168,7 @@ class CodeCompletions:
         suffix: str,
         file_name: str,
         editor_lang: Optional[str] = None,
-        raw_prompt: Optional[str] = None,
+        raw_prompt: Optional[str | list[Message]] = None,
         code_context: Optional[list] = None,
         stream: bool = False,
         snowplow_event_context: Optional[SnowplowEventContext] = None,
@@ -196,6 +196,10 @@ class CodeCompletions:
                 params = {"prefix": prompt.prefix, "suffix": prompt.suffix}
                 res = await self.model.generate(params, stream)
+            elif isinstance(self.model, ChatModelBase):
+                res = await self.model.generate(
+                    prompt.prefix, stream=stream, **kwargs
+                )
             else:
                 res = await self.model.generate(
                     prompt.prefix, prompt.suffix, stream, **kwargs
                 )
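
With the new branch, generate is dispatched on the model type. A simplified sketch of the resulting three-way dispatch (abridged: the real execute also handles instrumentation, token metadata, and post-processing):

```python
from ai_gateway.models import ChatModelBase
from ai_gateway.models.agent_model import AgentModel


async def _generate(model, prompt, stream, **kwargs):
    if isinstance(model, AgentModel):
        # Agent models take a single params dict.
        return await model.generate(
            {"prefix": prompt.prefix, "suffix": prompt.suffix}, stream
        )
    if isinstance(model, ChatModelBase):
        # Chat models receive the raw chat prompt: for prompt version 3,
        # _get_prompt stored the incoming list[Message] in prompt.prefix,
        # and there is no suffix argument to pass.
        return await model.generate(prompt.prefix, stream=stream, **kwargs)
    # Plain text-generation models keep the prefix/suffix signature.
    return await model.generate(prompt.prefix, prompt.suffix, stream, **kwargs)
```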
diff --git a/ai_gateway/code_suggestions/container.py b/ai_gateway/code_suggestions/container.py
index 4d52a57fdeafbbe5ff2675e306b66b38d0f5b263..d66bb54d6a2e3d319322dc8f37131ee4a0942e98 100644
--- a/ai_gateway/code_suggestions/container.py
+++ b/ai_gateway/code_suggestions/container.py
@@ -101,6 +101,7 @@ class ContainerCodeCompletions(containers.DeclarativeContainer):
     tokenizer = providers.Dependency(instance_of=PreTrainedTokenizerFast)
     vertex_code_gecko = providers.Dependency(instance_of=TextGenModelBase)
     anthropic_claude = providers.Dependency(instance_of=TextGenModelBase)
+    anthropic_claude_chat = providers.Dependency(instance_of=ChatModelBase)
     litellm = providers.Dependency(instance_of=TextGenModelBase)
     agent_model = providers.Dependency(instance_of=TextGenModelBase)
     snowplow_instrumentator = providers.Dependency(instance_of=SnowplowInstrumentator)
@@ -128,12 +129,7 @@ class ContainerCodeCompletions(containers.DeclarativeContainer):
 
     anthropic = providers.Factory(
         CodeCompletions,
-        model=providers.Factory(
-            anthropic_claude,
-            name=KindAnthropicModel.CLAUDE_INSTANT_1_2,
-            stop_sequences=["</new_code>", anthropic.HUMAN_PROMPT],
-            max_tokens_to_sample=128,
-        ),
+        model=providers.Factory(anthropic_claude_chat),
         tokenization_strategy=providers.Factory(
             TokenizerTokenStrategy, tokenizer=tokenizer
         ),
@@ -201,6 +197,7 @@ class ContainerCodeSuggestions(containers.DeclarativeContainer):
         tokenizer=tokenizer,
         vertex_code_gecko=models.vertex_code_gecko,
         anthropic_claude=models.anthropic_claude,
+        anthropic_claude_chat=models.anthropic_claude_chat,
         litellm=models.litellm,
         agent_model=models.agent_model,
         config=config,
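
The model__name keyword the endpoints now pass works because dependency-injector routes double-underscore kwargs to the nested provider at call time. A standalone sketch with invented classes:

```python
from dependency_injector import containers, providers


class Model:
    def __init__(self, name: str):
        self.name = name


class CodeCompletions:
    def __init__(self, model: Model):
        self.model = model


class Container(containers.DeclarativeContainer):
    anthropic = providers.Factory(
        CodeCompletions,
        # model__name=... on the outer factory is forwarded to this
        # nested factory as its name argument.
        model=providers.Factory(Model),
    )


container = Container()
completions = container.anthropic(model__name="claude-3-5-sonnet-20240620")
assert completions.model.name == "claude-3-5-sonnet-20240620"
```

This is also why the container test below can select the model per call via completions.anthropic(model__name=...) instead of needing a dedicated provider per model.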
diff --git a/tests/api/v2/test_v2_code.py b/tests/api/v2/test_v2_code.py
index c2c78f8baaa399f912c88ea5e0dbd09b008ca907..4b090e669eb84769e675a6ef1a9b3efbcc9cfeef 100644
--- a/tests/api/v2/test_v2_code.py
+++ b/tests/api/v2/test_v2_code.py
@@ -295,17 +295,17 @@ class TestCodeCompletions:
                 },
                 False,
             ),
-            # prompt version 2
+            # prompt version 3
             (
-                2,
+                3,
                 "anthropic",
-                "claude-instant-1.2",
+                "claude-3-5-sonnet-20240620",
                 "def search",
                 {
                     "id": "id",
                     "model": {
                         "engine": "anthropic",
-                        "name": "claude-instant-1.2",
+                        "name": "claude-3-5-sonnet-20240620",
                         "lang": "python",
                         "tokens_consumption_metadata": None,
                     },
@@ -406,10 +406,23 @@ class TestCodeCompletions:
             }
         )
 
-        if mock_suggestions_engine == "anthropic":
-            code_completions_kwargs.update(
-                {"raw_prompt": current_file["content_above_cursor"]}
-            )
+        if prompt_version == 3 and mock_suggestions_engine == "anthropic":
+            raw_prompt = [
+                {
+                    "role": "system",
+                    "content": "You are a code completion tool that performs Fill-in-the-middle. ",
+                },
+                {
+                    "role": "user",
+                    "content": "\n// write a function to find the max\n\n\n\n\treturn min\n}\n",
+                },
+            ]
+            data.update({"prompt": raw_prompt})
+            raw_prompt = [
+                Message(role=prompt["role"], content=prompt["content"])
+                for prompt in raw_prompt
+            ]
+            code_completions_kwargs.update({"raw_prompt": raw_prompt})
 
         response = mock_client.post(
             "/completions",
@@ -457,6 +470,7 @@ class TestCodeCompletions:
             "project_id": 278964,
             "model_provider": "anthropic",
             "current_file": current_file,
+            "model_name": "claude-3-5-sonnet-20240620",
         }
 
         code_completions_kwargs = {}
@@ -795,6 +809,7 @@ class TestCodeCompletions:
             "current_file": current_file,
             "stream": True,
             "context": context,
+            "model_name": "claude-3-5-sonnet-20240620",
         }
         if model:
             data["model_name"] = model
@@ -994,6 +1009,7 @@ class TestCodeCompletions:
             "project_id": 278964,
             "model_provider": "anthropic",
             "current_file": current_file,
+            "model_name": "claude-3-5-sonnet-20240620",
         }
 
         response = mock_client.post(
diff --git a/tests/api/v3/test_v3_code.py b/tests/api/v3/test_v3_code.py
index ceef3c24e4e5bbe3f051c9ac9eebd6a59d7029cb..cc5b5a357545ea45407027b3c39a88afd0865f79 100644
--- a/tests/api/v3/test_v3_code.py
+++ b/tests/api/v3/test_v3_code.py
@@ -173,7 +173,7 @@ class TestEditorContentCompletion:
             },
             {
                 "engine": "anthropic",
-                "name": "claude-instant-1.2",
+                "name": "claude-3-5-sonnet-20240620",
                 "lang": "python",
             },
         ),
@@ -208,7 +208,7 @@ class TestEditorContentCompletion:
     def test_model_provider(
         self,
         mock_client: TestClient,
-        mock_anthropic: Mock,
+        mock_anthropic_chat: Mock,
         mock_code_gecko: Mock,
         model_provider: str,
         expected_code: int,
@@ -221,6 +221,17 @@ class TestEditorContentCompletion:
             "content_below_cursor": "\n",
             "language_identifier": "python",
             "model_provider": model_provider or None,
+            "model_name": "claude-3-5-sonnet-20240620",
+            "prompt": [
+                {
+                    "role": "system",
+                    "content": "You are a code completion tool that performs Fill-in-the-middle",
+                },
+                {
+                    "role": "user",
+                    "content": "\n// a function to find the max\n for \n\n\n\n\treturn min\n}\n",
+                },
+            ],
         }
 
         prompt_component = {
@@ -255,7 +266,7 @@ class TestEditorContentCompletion:
         assert body["metadata"]["model"] == expected_model
 
-        mock = mock_anthropic if model_provider == "anthropic" else mock_code_gecko
-        mock.assert_called
+        mock = mock_anthropic_chat if model_provider == "anthropic" else mock_code_gecko
+        mock.assert_called()
@@ -272,6 +283,7 @@ class TestEditorContentCompletion:
             "language_identifier": "python",
             "model_provider": "anthropic",
             "stream": True,
+            "model_name": "claude-3-5-sonnet-20240620",
         }
 
         prompt_component = {
@@ -319,6 +331,7 @@ class TestEditorContentCompletion:
             stream=True,
             code_context=None,
             snowplow_event_context=expected_snowplow_event,
+            raw_prompt=None,
         )
@@ -857,6 +870,7 @@ class TestIncomingRequest:
                     "content_below_cursor": "",
                     # FIXME: Forcing anthropic as vertex-ai is not working
                    "model_provider": "anthropic",
+                    "model_name": "claude-3-5-sonnet-20240620",
                },
            },
        ],
@@ -873,6 +887,7 @@ class TestIncomingRequest:
                    "file_name": "test",
                    "content_above_cursor": "def hello_world():",
                    "content_below_cursor": "",
+                    "model_name": "claude-3-5-sonnet-20240620",
                },
            },
        ],
@@ -889,6 +904,7 @@ class TestIncomingRequest:
                    "file_name": "test",
                    "content_above_cursor": "def hello_world():",
                    "content_below_cursor": "",
+                    "model_name": "claude-3-5-sonnet-20240620",
                },
            },
        ]
@@ -905,6 +921,7 @@ class TestIncomingRequest:
                "payload": {
                    "content_above_cursor": "def hello_world():",
                    "content_below_cursor": "",
+                    "model_name": "claude-3-5-sonnet-20240620",
                },
            },
        ],
diff --git a/tests/code_suggestions/test_container.py b/tests/code_suggestions/test_container.py
index b7ee620110637b5f184ed44d67a626f9b4dfee12..2cb7d7ba74f45f8316b66999006398650e697b5c 100644
--- a/tests/code_suggestions/test_container.py
+++ b/tests/code_suggestions/test_container.py
@@ -15,7 +15,10 @@ def test_container(mock_container: containers.DeclarativeContainer):
     generations = code_suggestions.generations
 
     assert isinstance(completions.vertex_legacy(), CodeCompletionsLegacy)
-    assert isinstance(completions.anthropic(), CodeCompletions)
+    assert isinstance(
+        completions.anthropic(model__name=KindAnthropicModel.CLAUDE_3_5_SONNET),
+        CodeCompletions,
+    )
     assert isinstance(
         completions.litellm_factory(model__name=KindLiteLlmModel.CODEGEMMA),
         CodeCompletions,