chore(wren-ai-service): Try to make evaluation work again #1085

Open · wants to merge 5 commits into base: main

Changes from 2 commits
182 changes: 144 additions & 38 deletions wren-ai-service/eval/pipelines.py
@@ -2,14 +2,35 @@
import os
import re
import sys
import uuid
from abc import abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Literal

import orjson
import json
from haystack import Document
from langfuse.decorators import langfuse_context, observe
from tqdm.asyncio import tqdm_asyncio
from src.config import settings
from src.providers import generate_components
from src.web.v1.services.semantics_preparation import (
SemanticsPreparationRequest,
SemanticsPreparationService,
)
from src.web.v1.services.ask import (
AskRequest,
AskResultRequest,
AskResultResponse,
AskService,
)
from src.pipelines.generation import (
data_assistance,
intent_classification,
sql_correction,
sql_generation,
)
from src.pipelines.retrieval import historical_question, retrieval

sys.path.append(f"{Path().parent.resolve()}")

@@ -32,15 +53,15 @@
from src.core.engine import Engine
from src.core.provider import DocumentStoreProvider, EmbedderProvider, LLMProvider
from src.pipelines.generation import sql_generation
from src.pipelines.indexing import indexing
from src.pipelines.retrieval import retrieval
from src.pipelines import indexing


def deploy_model(mdl: str, pipe: indexing.Indexing) -> None:
async def wrapper():
await pipe.run(orjson.dumps(mdl).decode())
# def deploy_model(mdl: str, pipe: indexing.Indexing) -> None:
# async def wrapper():
# await pipe.run(orjson.dumps(mdl).decode())

asyncio.run(wrapper())
# asyncio.run(wrapper())


def extract_units(docs: list) -> list:
@@ -107,6 +128,7 @@ def split(queries: list, batch_size: int) -> list[list]:
]

async def wrapper(batch: list):
# self() will call sub-class's __call__ in every service
tasks = [self(query) for query in batch]
results = await tqdm_asyncio.gather(*tasks, desc="Generating Predictions")
await asyncio.sleep(self._batch_interval)
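The batching pattern in this hunk — gather one batch of coroutines concurrently, then sleep before the next batch — can be sketched in isolation. `run_in_batches`, `predict_one`, and the uppercase "prediction" are illustrative stand-ins, not code from this PR:

```python
import asyncio

async def predict_one(query: str) -> str:
    # Stand-in for a single evaluation call (e.g. self(query) in the PR).
    await asyncio.sleep(0.01)
    return query.upper()

def split(queries: list, batch_size: int) -> list[list]:
    return [queries[i:i + batch_size] for i in range(0, len(queries), batch_size)]

async def run_in_batches(queries: list, batch_size: int, interval: float) -> list:
    results = []
    for batch in split(queries, batch_size):
        # Run one batch concurrently, then pause between batches (rate limiting).
        results += await asyncio.gather(*(predict_one(q) for q in batch))
        await asyncio.sleep(interval)
    return results

print(asyncio.run(run_in_batches(["a", "b", "c"], batch_size=2, interval=0.0)))  # → ['A', 'B', 'C']
```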
Expand Down Expand Up @@ -188,7 +210,7 @@ def __init__(
embedder_provider=embedder_provider,
document_store_provider=document_store_provider,
)
deploy_model(mdl, _indexing)
# deploy_model(mdl, _indexing)

self._retrieval = retrieval.Retrieval(
llm_provider=llm_provider,
@@ -288,36 +310,82 @@ def mertics(
}



class AskPipeline(Eval):
def indexing_service(self):

return SemanticsPreparationService(
{
"db_schema": indexing.DBSchema(
**self.pipe_components["db_schema_indexing"],
),
"historical_question": indexing.HistoricalQuestion(
**self.pipe_components["historical_question_indexing"],
),
"table_description": indexing.TableDescription(
**self.pipe_components["table_description_indexing"],
),
}
)
Comment on lines +315 to +329 — Contributor

🛠️ Refactor suggestion

Add error handling for missing components.

The function accesses self.pipe_components dictionary without checking if required components exist. This could raise KeyError exceptions if any required component is missing.

Add error handling to safely access required components:

     def indexing_service(self):
+        required_components = [
+            "db_schema_indexing",
+            "historical_question_indexing",
+            "table_description_indexing"
+        ]
+        missing_components = [comp for comp in required_components if comp not in self.pipe_components]
+        if missing_components:
+            raise ValueError(f"Missing required components: {missing_components}")
+
         return SemanticsPreparationService(
             {
                 "db_schema": indexing.DBSchema(


def ask_service(self):

return AskService(
{
"intent_classification": intent_classification.IntentClassification(
**self.pipe_components["intent_classification"],
),
"data_assistance": data_assistance.DataAssistance(
**self.pipe_components["data_assistance"],
),
"retrieval": retrieval.Retrieval(
**self.pipe_components["db_schema_retrieval"],
),
"historical_question": historical_question.HistoricalQuestion(
**self.pipe_components["historical_question_retrieval"],
),
"sql_generation": sql_generation.SQLGeneration(
**self.pipe_components["sql_generation"],
),
"sql_correction": sql_correction.SQLCorrection(
**self.pipe_components["sql_correction"],
),
}
)
Comment on lines +331 to +354 — Contributor

🛠️ Refactor suggestion

Add error handling for missing components.

Similar to indexing_service, this function accesses self.pipe_components dictionary without checking if required components exist.

Add error handling to safely access required components:

     def ask_service(self):
+        required_components = [
+            "intent_classification",
+            "data_assistance",
+            "db_schema_retrieval",
+            "historical_question_retrieval",
+            "sql_generation",
+            "sql_correction"
+        ]
+        missing_components = [comp for comp in required_components if comp not in self.pipe_components]
+        if missing_components:
+            raise ValueError(f"Missing required components: {missing_components}")
+
         return AskService(

def dict_to_string(self, d: dict) -> str:
if not isinstance(d, dict):
return str(d)

result = "{"
for key, value in d.items():
result += f"'{key}': {self.dict_to_string(value)}, "
result = result.rstrip(", ") + "}"
return result
Comment on lines +355 to +363 — Contributor

🛠️ Refactor suggestion

Add error handling for circular references and unhashable keys.

The recursive dictionary to string conversion could fail in several scenarios:

  1. Circular references could cause infinite recursion
  2. Unhashable dictionary keys could raise TypeError

Add error handling to handle these cases:

     def dict_to_string(self, d: dict, seen=None) -> str:
+        if seen is None:
+            seen = set()
+
         if not isinstance(d, dict):
             return str(d)
 
+        # Check for circular references
+        d_id = id(d)
+        if d_id in seen:
+            return "{...}"  # Indicate circular reference
+        seen.add(d_id)
+
         result = "{"
-        for key, value in d.items():
-            result += f"'{key}': {self.dict_to_string(value)}, "
+        try:
+            for key, value in d.items():
+                result += f"'{key}': {self.dict_to_string(value, seen)}, "
+        except TypeError as e:
+            return f"{{Error: {str(e)}}}"
+
         result = result.rstrip(", ") + "}"
+        seen.remove(d_id)
         return result
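As an aside, the hand-rolled `dict_to_string` serializer approximates what the standard library already provides; a minimal sketch of the same idea via `json.dumps` (the fall-back-to-`str` behavior is an assumption; note the output uses JSON double quotes rather than the repr-style single quotes of the original):

```python
import json

def dict_to_string(d) -> str:
    # Non-dict values fall back to str(), as in the hand-rolled version.
    if not isinstance(d, dict):
        return str(d)
    # default=str stringifies non-JSON-serializable leaves; json.dumps also
    # rejects circular references with ValueError instead of recursing forever.
    return json.dumps(d, default=str)

print(dict_to_string({"a": 1, "b": {"c": [1, 2]}}))  # → {"a": 1, "b": {"c": [1, 2]}}
```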


def __init__(
self,
meta: dict,
mdl: dict,
llm_provider: LLMProvider,
embedder_provider: EmbedderProvider,
document_store_provider: DocumentStoreProvider,
engine: Engine,
**kwargs,
service_metadata,
pipe_components,
):
super().__init__(meta, 3)

document_store_provider.get_store(recreate_index=True)
_indexing = indexing.Indexing(
embedder_provider=embedder_provider,
document_store_provider=document_store_provider,
)
deploy_model(mdl, _indexing)

self.service_metadata = service_metadata

# document_store_provider.get_store(recreate_index=True)
# _indexing = indexing.Indexing(
# embedder_provider=embedder_provider,
# document_store_provider=document_store_provider,
# )
# deploy_model(mdl, _indexing)
self.pipe_components = pipe_components
self.project_id = str(uuid.uuid4().int >> 65)
self.indexing_service_var = self.indexing_service()
self.mdl_str_var = json.dumps(mdl)
self.ask_service_var = self.ask_service()
self.service_metadata = service_metadata
self._mdl = mdl
self._retrieval = retrieval.Retrieval(
llm_provider=llm_provider,
embedder_provider=embedder_provider,
document_store_provider=document_store_provider,
)
self._generation = sql_generation.SQLGeneration(
llm_provider=llm_provider,
engine=engine,
)
self.mdl_hash = str(hash(self.mdl_str_var))
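One detail worth flagging in `str(hash(self.mdl_str_var))`: Python's built-in `hash` on strings is salted per process (`PYTHONHASHSEED`), so `mdl_hash` is not stable across runs. A deterministic alternative sketched with `hashlib` (the helper name is hypothetical, not part of this PR):

```python
import hashlib

def stable_mdl_hash(mdl_str: str) -> str:
    # sha256 is deterministic across processes and Python versions, unlike
    # the per-process salted built-in hash() for str.
    return hashlib.sha256(mdl_str.encode("utf-8")).hexdigest()
```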

async def _flat(self, prediction: dict, actual: str) -> dict:
prediction["actual_output"] = actual
@@ -327,17 +395,54 @@ async def _flat(self, prediction: dict, **_) -> dict:
return prediction

async def _process(self, prediction: dict, **_) -> dict:
result = await self._retrieval.run(query=prediction["input"])
documents = result.get("construct_retrieval_results", [])
actual_output = await self._generation.run(

await self.indexing_service_var.prepare_semantics(
SemanticsPreparationRequest(
mdl=self.mdl_str_var,
mdl_hash=self.mdl_hash,
project_id=self.project_id
),
service_metadata=self.service_metadata,
)

# asking
ask_request = AskRequest(
query=prediction["input"],
contexts=documents,
samples=prediction["samples"],
exclude=[],
mdl_hash=self.mdl_hash,
project_id = self.project_id,

)
ask_request.query_id = str(uuid.uuid4().int >> 65)
await self.ask_service_var.ask(ask_request, service_metadata=self.service_metadata)
# getting ask result
ask_result_response = self.ask_service_var.get_ask_result(
AskResultRequest(
query_id=ask_request.query_id,
)
)

prediction["actual_output"] = actual_output
prediction["retrieval_context"] = extract_units(documents)
while (
ask_result_response.status != "finished"
and ask_result_response.status != "failed"
):
# getting ask result
ask_result_response = self.ask_service_var.get_ask_result(
AskResultRequest(
query_id=ask_request.query_id,
)
)

# result = await self._retrieval.run(query=prediction["input"])
# documents = result.get("construct_retrieval_results", [])
# actual_output = await self._generation.run(
# query=prediction["input"],
# contexts=documents,
# samples=prediction["samples"],
# exclude=[],
# )

prediction["actual_output"] = ask_result_response.response[0].sql
#prediction["retrieval_context"] = extract_units(documents)
Comment on lines +398 to +445 — Contributor

🛠️ Refactor suggestion

Add timeout, backoff, and error handling to the polling loop.

The current implementation has several issues:

  1. No timeout mechanism for the polling loop
  2. No backoff strategy between retries
  3. No error handling for service calls

Add these safety mechanisms:

     async def _process(self, prediction: dict, **_) -> dict:
+        MAX_RETRIES = 10
+        INITIAL_BACKOFF = 1  # seconds
+        MAX_BACKOFF = 32  # seconds
+
+        try:
             await self.indexing_service_var.prepare_semantics(
                 SemanticsPreparationRequest(
                     mdl=self.mdl_str_var,
                     mdl_hash=self.mdl_hash,
                     project_id=self.project_id
                 ),
                 service_metadata=self.service_metadata,
             )
 
             ask_request = AskRequest(
                 query=prediction["input"],
                 mdl_hash=self.mdl_hash,
                 project_id=self.project_id,
             )
             ask_request.query_id = str(uuid.uuid4().int >> 65)
             await self.ask_service_var.ask(ask_request, service_metadata=self.service_metadata)
 
             ask_result_response = self.ask_service_var.get_ask_result(
                 AskResultRequest(
                     query_id=ask_request.query_id,
                 )
             )
 
+            retries = 0
+            backoff = INITIAL_BACKOFF
             while (
                 ask_result_response.status != "finished"
                 and ask_result_response.status != "failed"
+                and retries < MAX_RETRIES
             ):
+                await asyncio.sleep(backoff)
+                backoff = min(backoff * 2, MAX_BACKOFF)
+                retries += 1
+
                 ask_result_response = self.ask_service_var.get_ask_result(
                     AskResultRequest(
                         query_id=ask_request.query_id,
                     )
                 )
 
+            if retries >= MAX_RETRIES:
+                raise TimeoutError("Ask service request timed out")
+
+            if ask_result_response.status == "failed":
+                raise RuntimeError(f"Ask service request failed: {ask_result_response.error}")
+
             prediction["actual_output"] = ask_result_response.response[0].sql
+        except Exception as e:
+            logger.exception("Error in _process: %s", str(e))
+            prediction["actual_output"] = None
+            prediction["error"] = str(e)
 
         return prediction


return prediction

@@ -377,9 +482,10 @@ def init(
name: Literal["retrieval", "generation", "ask"],
meta: dict,
mdl: dict,
providers: Dict[str, Any],
service_metadata,
pipe_components: Dict[str, Any],
) -> Eval:
args = {"meta": meta, "mdl": mdl, **providers}
args = {"meta": meta, "mdl": mdl, "service_metadata":service_metadata,"pipe_components":pipe_components}
match name:
case "retrieval":
return RetrievalPipeline(**args)
35 changes: 20 additions & 15 deletions wren-ai-service/eval/prediction.py
@@ -14,13 +14,17 @@
from tomlkit import document, dumps

sys.path.append(f"{Path().parent.resolve()}")
from src.config import settings
from src.providers import generate_components
import eval.pipelines as pipelines
import src.providers as provider
import src.utils as utils
from eval.utils import parse_toml
from src.core.engine import EngineConfig
from src.core.provider import EmbedderProvider, LLMProvider

from src.globals import (
create_service_container,
create_service_metadata,
)

def generate_meta(
path: str,
@@ -46,10 +50,10 @@ def generate_meta(
"commit": obtain_commit_hash(),
"embedding_model": embedder_provider.get_model(),
"generation_model": llm_provider.get_model(),
"column_indexing_batch_size": int(os.getenv("COLUMN_INDEXING_BATCH_SIZE"))
"column_indexing_batch_size": int(settings.column_indexing_batch_size)
or 50,
"table_retrieval_size": int(os.getenv("TABLE_RETRIEVAL_SIZE")) or 10,
"table_column_retrieval_size": int(os.getenv("TABLE_COLUMN_RETRIEVAL_SIZE"))
"table_retrieval_size": int(settings.table_retrieval_size) or 10,
"table_column_retrieval_size": int(settings.table_column_retrieval_size)
or 100,
"pipeline": pipe,
"batch_size": os.getenv("BATCH_SIZE") or 4,
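One subtlety in the `or`-defaulting used throughout this hunk: `int(x) or 50` silently replaces an explicit `0` with the default, since `0` is falsy. A stricter helper that treats only missing values as unset (name and defaults are illustrative assumptions):

```python
def setting_or_default(value, default: int) -> int:
    # Treat only None / empty string as "unset"; an explicit 0 is preserved,
    # which `int(value) or default` would silently replace.
    if value is None or value == "":
        return default
    return int(value)

print(setting_or_default(None, 50))  # → 50
print(setting_or_default("0", 50))   # → 0
print(setting_or_default("25", 50))  # → 25
```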
@@ -138,12 +142,12 @@ def init_providers(mdl: dict) -> dict:
if engine_config is None:
raise ValueError("Invalid datasource")

providers = provider.init_providers(engine_config=engine_config)
providers_inner = provider.init_providers(engine_config=engine_config)
Contributor

⚠️ Potential issue

Undefined name provider.
This call to provider.init_providers will fail unless provider is imported or defined. Fix this reference.

- providers_inner = provider.init_providers(engine_config=engine_config)
+ from src.providers import provider
+ providers_inner = provider.init_providers(engine_config=engine_config)

Committable suggestion skipped: line range outside the PR's diff.

🪛 Ruff (0.8.2): 145-145: Undefined name `provider` (F821)

return {
"llm_provider": providers[0],
"embedder_provider": providers[1],
"document_store_provider": providers[2],
"engine": providers[3],
"llm_provider": providers_inner[0],
"embedder_provider": providers_inner[1],
"document_store_provider": providers_inner[2],
"engine": providers_inner[3],
}


@@ -174,23 +178,24 @@ def parse_args() -> Tuple[str]:
utils.init_langfuse()

dataset = parse_toml(path)
providers = init_providers(dataset["mdl"])

pipe_components = generate_components(settings.components)
meta = generate_meta(
path=path,
dataset=dataset,
pipe=pipe_name,
**providers,
**pipe_components["db_schema_retrieval"],
)

service_metadata = create_service_metadata(pipe_components)
pipe = pipelines.init(
pipe_name,
meta,
mdl=dataset["mdl"],
providers=providers,
service_metadata=service_metadata,
pipe_components=pipe_components,
Comment on lines +182 to +195 — Contributor

🛠️ Refactor suggestion

Add error handling for component generation.

The component generation and service metadata creation could fail silently. Consider adding error handling:

-    pipe_components = generate_components(settings.components)
-    meta = generate_meta(
-        path=path,
-        dataset=dataset,
-        pipe=pipe_name,
-        **pipe_components["db_schema_retrieval"],
-    )
-    service_metadata = create_service_metadata(pipe_components)
+    try:
+        pipe_components = generate_components(settings.components)
+        if not pipe_components.get("db_schema_retrieval"):
+            raise ValueError("Required component 'db_schema_retrieval' not found")
+        
+        meta = generate_meta(
+            path=path,
+            dataset=dataset,
+            pipe=pipe_name,
+            **pipe_components["db_schema_retrieval"],
+        )
+        service_metadata = create_service_metadata(pipe_components)
+    except Exception as e:
+        raise RuntimeError(f"Failed to initialize components: {str(e)}") from e

)

predictions = pipe.predict(dataset["eval_dataset"])
predictions = pipe.predict([dataset["eval_dataset"][0]])
meta["expected_batch_size"] = meta["query_count"] * pipe.candidate_size
meta["actual_batch_size"] = len(predictions) - meta["query_count"]

7 changes: 7 additions & 0 deletions wren-ai-service/src/globals.py
@@ -67,6 +67,13 @@ class ServiceContainer:
class ServiceMetadata:
pipes_metadata: dict
service_version: str
def get(self, key: str):
if key=="service_version":
return self.service_version
elif key=="pipes_metadata":
return self.pipes_metadata
else:
return None
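The string-keyed `get` added here can be collapsed to a single `getattr` call, which keeps working as fields are added; a sketch under the assumption that `ServiceMetadata` remains a plain dataclass:

```python
from dataclasses import dataclass

@dataclass
class ServiceMetadata:
    pipes_metadata: dict
    service_version: str

    def get(self, key: str):
        # getattr covers every dataclass field automatically;
        # the None default mirrors dict.get's behavior for unknown keys.
        return getattr(self, key, None)

meta = ServiceMetadata(pipes_metadata={}, service_version="0.1.0")
print(meta.get("service_version"))  # → 0.1.0
print(meta.get("missing"))          # → None
```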


def create_service_container(
2 changes: 1 addition & 1 deletion wren-ai-service/src/pipelines/common.py
@@ -166,7 +166,7 @@ async def _task(result: Dict[str, str]):

if no_error:
status, _, addition = await self._engine.execute_sql(
quoted_sql, session, project_id=project_id
quoted_sql, session, project_id=int(project_id)
Contributor

💡 Codebase verification · 🛠️ Refactor suggestion

Unsafe type casting of project_id needs revision

The type casting of project_id to integer is problematic because:

  • It's inconsistent with the codebase where project_id is handled as an optional string
  • It can raise ValueError for None values, which are valid according to the function signatures

Suggested fix:

quoted_sql, session, project_id=str(project_id) if project_id is not None else None
🔗 Analysis chain

Verify project_id type casting and add error handling.

The change to cast project_id to integer might fix the evaluation error, but it could raise a ValueError if project_id is None or not a valid integer string.

Add error handling to safely handle invalid project_id values:

-                    quoted_sql, session, project_id=int(project_id)
+                    quoted_sql, session, project_id=int(project_id) if project_id is not None else None

)

if status:
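Following the reviewer's point that `project_id` is an optional string elsewhere in the codebase, a defensive conversion helper could absorb both `None` and non-numeric values (the helper name is hypothetical):

```python
from typing import Optional

def coerce_project_id(project_id: Optional[str]) -> Optional[int]:
    # Accept None and non-numeric strings without raising, matching the
    # optional-string handling of project_id elsewhere in the codebase.
    if project_id is None:
        return None
    try:
        return int(project_id)
    except ValueError:
        return None

print(coerce_project_id("42"))   # → 42
print(coerce_project_id(None))   # → None
print(coerce_project_id("abc"))  # → None
```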