Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion python/packages/kagent-adk/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ dependencies = [
"protobuf>=6.33.5", # CVE-2026-0994: Denial of Service due to recursion depth bypass
"anthropic[vertex]>=0.49.0",
"fastapi>=0.115.1",
"litellm>=1.81.0,<2.0", # CVE-2025-45809: SQL injection in earlier versions; pin above all known affected releases
"google-adk>=1.25.0",
"google-genai>=1.21.1",
"google-auth>=2.40.2",
Expand Down
137 changes: 101 additions & 36 deletions python/packages/kagent-adk/src/kagent/adk/_memory_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from google.adk.models import BaseLlm
from google.adk.sessions import Session
from google.genai import types
from litellm import aembedding

from kagent.adk.types import EmbeddingConfig

Expand Down Expand Up @@ -303,7 +302,7 @@ def _normalize_l2(self, x):
async def _generate_embedding_async(
self, input_data: Union[str, List[str]]
) -> Union[List[float], List[List[float]]]:
"""Generate embedding vector(s) using LiteLLM.
"""Generate embedding vector(s) using provider-specific SDK clients.

Args:
input_data: Single string or list of strings to embed.
Expand All @@ -324,48 +323,114 @@ async def _generate_embedding_async(
logger.warning("No embedding model specified in config")
return []

# Build LiteLLM model identifier
litellm_model = model_name
if provider and provider != "openai" and "/" not in model_name:
if provider == "azure_openai":
litellm_model = f"azure/{model_name}"
elif provider == "ollama":
litellm_model = f"ollama/{model_name}"
elif provider == "vertex_ai":
litellm_model = f"vertex_ai/{model_name}"
elif provider == "gemini":
litellm_model = f"gemini/{model_name}"
is_batch = isinstance(input_data, list)
texts = input_data if is_batch else [input_data]
api_base = self.embedding_config.base_url or None

try:
is_batch = isinstance(input_data, list)
texts = input_data if is_batch else [input_data]
raw_embeddings = await self._call_embedding_provider(provider, model_name, texts, api_base)
except Exception as e:
logger.error("Error generating embedding with provider=%s model=%s: %s", provider, model_name, e)
return []

# Most Matryoshka Representation Learning embedding models produce embeddings that still have
# meaning when truncated to specific sizes: https://huggingface.co/blog/matryoshka
# We must ensure embeddings have consistent dimensions for the vector storage backend.
embeddings = []
for embedding in raw_embeddings:
if len(embedding) > 768:
embedding = embedding[:768]
embedding = self._normalize_l2(embedding)
Comment thread
jmhbh marked this conversation as resolved.
Outdated
embeddings.append(embedding)

if is_batch:
return embeddings
return embeddings[0] if embeddings else []

async def _call_embedding_provider(
self,
provider: str,
model_name: str,
texts: List[str],
api_base: Optional[str],
) -> List[List[float]]:
"""Dispatch to the correct provider SDK for embedding generation."""
if provider in ("openai", "azure_openai"):
return await self._embed_openai(provider, model_name, texts, api_base)
if provider == "ollama":
return await self._embed_ollama(model_name, texts, api_base)
if provider in ("vertex_ai", "gemini"):
return await self._embed_google(provider, model_name, texts)
# Unknown provider — try OpenAI-compatible as a fallback
logger.warning("Unknown embedding provider '%s'; attempting OpenAI-compatible call.", provider)
return await self._embed_openai("openai", model_name, texts, api_base)

async def _embed_openai(
    self,
    provider: str,
    model_name: str,
    texts: List[str],
    api_base: Optional[str],
) -> List[List[float]]:
    """Embed *texts* using the OpenAI or Azure OpenAI SDK.

    Args:
        provider: Either "openai" or "azure_openai".
        model_name: Embedding model identifier (e.g. "text-embedding-3-small").
        texts: Batch of strings to embed.
        api_base: Optional endpoint override (Azure endpoint, or an
            OpenAI-compatible base URL for the plain OpenAI path).

    Returns:
        One embedding vector per input text, in input order.

    Raises:
        ValueError: If the Azure endpoint cannot be resolved from either
            the config's base_url or the AZURE_OPENAI_ENDPOINT env var.
    """
    import os

    if provider == "azure_openai":
        from openai import AsyncAzureOpenAI

        # Azure requires an explicit API version and endpoint.
        # NOTE(review): assumes the SDK resolves the API key from the
        # environment (e.g. AZURE_OPENAI_API_KEY) — confirm deployment config.
        api_version = os.environ.get("OPENAI_API_VERSION", "2024-02-15-preview")
        azure_endpoint = api_base or os.environ.get("AZURE_OPENAI_ENDPOINT")
        if not azure_endpoint:
            raise ValueError("Azure OpenAI endpoint must be set via base_url or AZURE_OPENAI_ENDPOINT env var")
        client = AsyncAzureOpenAI(api_version=api_version, azure_endpoint=azure_endpoint)
    else:
        from openai import AsyncOpenAI

        # base_url=None falls back to the public OpenAI endpoint.
        client = AsyncOpenAI(base_url=api_base or None)

    # dimensions=768 requests server-side truncation (Matryoshka-style models)
    # so every backend returns consistently sized vectors for storage.
    response = await client.embeddings.create(model=model_name, input=texts, dimensions=768)
    return [item.embedding for item in response.data]
async def _embed_ollama(
    self,
    model_name: str,
    texts: List[str],
    api_base: Optional[str],
) -> List[List[float]]:
    """Embed *texts* using the Ollama SDK.

    Args:
        model_name: Name of the locally served embedding model.
        texts: Batch of strings to embed.
        api_base: Optional Ollama host override; falls back to the
            OLLAMA_API_BASE env var, then the default local endpoint.

    Returns:
        One embedding vector per input text, in input order.
    """
    import os

    import ollama

    # Resolution order: explicit config > environment > local default.
    host = api_base or os.environ.get("OLLAMA_API_BASE", "http://localhost:11434")
    client = ollama.AsyncClient(host=host)
    result = await client.embed(model=model_name, input=texts)
    return list(result.embeddings)
async def _embed_google(
    self,
    provider: str,
    model_name: str,
    texts: List[str],
) -> List[List[float]]:
    """Embed *texts* using google-genai (Gemini API or Vertex AI).

    Args:
        provider: "vertex_ai" to authenticate against Vertex AI; any other
            value uses the Gemini API client defaults.
        model_name: Embedding model identifier.
        texts: Batch of strings to embed.

    Returns:
        One 768-dimensional embedding vector per input text, in input order.
    """
    from google import genai
    from google.genai import types as genai_types

    client = genai.Client(vertexai=True) if provider == "vertex_ai" else genai.Client()

    # Hoist the config: identical for every request in the batch.
    config = genai_types.EmbedContentConfig(output_dimensionality=768)
    embeddings: List[List[float]] = []
    for text in texts:
        # Use the async surface (client.aio) — the synchronous
        # client.models.embed_content call would block the event loop
        # inside this coroutine.
        response = await client.aio.models.embed_content(
            model=model_name,
            contents=text,
            config=config,
        )
        embeddings.append(list(response.embeddings[0].values))
    return embeddings

async def _summarize_session_content_async(
self,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from ._litellm import KAgentLiteLlm
from ._anthropic import KAgentAnthropicLlm
from ._bedrock import KAgentBedrockLlm
from ._ollama import KAgentOllamaLlm
from ._openai import AzureOpenAI, OpenAI

__all__ = ["OpenAI", "AzureOpenAI", "KAgentLiteLlm"]
__all__ = ["OpenAI", "AzureOpenAI", "KAgentAnthropicLlm", "KAgentBedrockLlm", "KAgentOllamaLlm"]
43 changes: 43 additions & 0 deletions python/packages/kagent-adk/src/kagent/adk/models/_anthropic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Anthropic model implementation with api_key_passthrough, base_url, and header support."""

from __future__ import annotations

import logging
import os
from functools import cached_property
from typing import Optional

from anthropic import AsyncAnthropic
from google.adk.models.anthropic_llm import AnthropicLlm

logger = logging.getLogger(__name__)


class KAgentAnthropicLlm(AnthropicLlm):
    """Anthropic model with api_key_passthrough, custom base_url, and header support."""

    # When set, the A2A request's Bearer token is used as the Anthropic API key.
    api_key_passthrough: Optional[bool] = None

    _api_key: Optional[str] = None
    base_url: Optional[str] = None
    extra_headers: Optional[dict[str, str]] = None

    model_config = {"arbitrary_types_allowed": True}

    def set_passthrough_key(self, token: str) -> None:
        """Forward the Bearer token from the incoming A2A request as the Anthropic API key."""
        self._api_key = token
        # Drop the memoized client so the next access rebuilds it with the new key.
        self.__dict__.pop("_anthropic_client", None)

    @cached_property
    def _anthropic_client(self) -> AsyncAnthropic:
        # A passthrough token takes precedence over the ambient environment key.
        resolved_key = self._api_key or os.environ.get("ANTHROPIC_API_KEY")
        # Only forward settings that are actually set (truthy), mirroring the
        # SDK's own defaults for anything omitted.
        client_kwargs = {
            name: value
            for name, value in (
                ("api_key", resolved_key),
                ("base_url", self.base_url),
                ("default_headers", self.extra_headers),
            )
            if value
        }
        return AsyncAnthropic(**client_kwargs)
Loading
Loading