From a9b872f5b2ca2e78b1840916dbbfef037a756cf0 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv <168006707+dnandakumar-nv@users.noreply.github.com> Date: Thu, 19 Sep 2024 17:54:46 -0500 Subject: [PATCH] Remove OpenAI dependency (#183) --- .../event-driven-rag-cve-analysis/Dockerfile | 3 + .../event-driven-rag-cve-analysis/README.md | 10 +- .../cyber_dev_day/config.py | 12 - .../cyber_dev_day/llm_service.py | 2 +- .../cyber_dev_day/nim_llm_service.py | 284 ++++++++++--- .../cyber_dev_day/openai_chat_service.py | 397 ------------------ .../cyber_dev_day/pipeline_utils.py | 168 +------- .../docker-compose.yml | 2 +- .../notebooks/cyber-dev-day.ipynb | 236 ++--------- 9 files changed, 284 insertions(+), 830 deletions(-) delete mode 100644 community/event-driven-rag-cve-analysis/cyber_dev_day/openai_chat_service.py diff --git a/community/event-driven-rag-cve-analysis/Dockerfile b/community/event-driven-rag-cve-analysis/Dockerfile index c7dc661d..65c50a75 100755 --- a/community/event-driven-rag-cve-analysis/Dockerfile +++ b/community/event-driven-rag-cve-analysis/Dockerfile @@ -66,5 +66,8 @@ RUN source activate morpheus &&\ jupyter contrib nbextension install --user &&\ pip install jupyterlab_nvdashboard==0.9 +RUN source activate morpheus &&\ + pip install --upgrade langchain-nvidia-ai-endpoints + # Launch jupyter CMD ["jupyter-lab", "--ip=0.0.0.0", "--no-browser", "--allow-root"] diff --git a/community/event-driven-rag-cve-analysis/README.md b/community/event-driven-rag-cve-analysis/README.md index 9d39a553..e5719cff 100644 --- a/community/event-driven-rag-cve-analysis/README.md +++ b/community/event-driven-rag-cve-analysis/README.md @@ -28,11 +28,7 @@ You will also need to have a `Morpheus 24.03` docker container built and present ### NVIDIA GPU Cloud -To access the NVIDIA hosted Inference Service, you will need to have the following environment variables set: `OPENAI_API_KEY`. To obtain the API key, please visit the [NVIDIA website](https://build.nvidia.com/) for instructions on generating your API key. - -It's important to note here that although we store the NGC API Key under the `OPENAI_API_KEY` variable, we will be interacting with NVIDIA hosted LLMs and not OpenAI LLMs. - -NVIDIA NIM microservices are OpenAI API compliant to maximize usability, so we will be using the `openai` with package as a wrapped to make API calls. +To access the NVIDIA hosted Inference Service, you will need to have the following environment variables set: `NVIDIA_API_KEY`. To obtain the API key, please visit the [NVIDIA website](https://build.nvidia.com/) for instructions on generating your API key. 
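+
+As a quick sanity check, you can confirm that the key is visible to the `langchain-nvidia-ai-endpoints` client from inside the container. This snippet is only an illustration; the model name is one example of a model hosted on [build.nvidia.com](https://build.nvidia.com/):
+
+```python
+import os
+
+from langchain_nvidia_ai_endpoints import ChatNVIDIA
+
+# ChatNVIDIA reads NVIDIA_API_KEY from the environment by default.
+assert os.environ.get("NVIDIA_API_KEY"), "NVIDIA_API_KEY is not set"
+
+llm = ChatNVIDIA(model="meta/llama3-70b-instruct")  # example model name
+print(llm.invoke("Reply with the single word: ready").content)
+```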
### Building a Morpheus Container @@ -53,13 +49,13 @@ If you are using a Morpheus version that is not `v24.03.02-runtime`, please upda ``` ### Creating an Environment File -To automatically use these API keys, you can set the `OPENAI_API_KEY` value in the `docker-compose.yml` file in this directory as follows: +To automatically use these API keys, you can set the `NVIDIA_API_KEY` value in the `docker-compose.yml` file in this directory as follows: ```bash environment: - TERM=${TERM:-} # Workaround until this is working: https://github.com/docker/compose/issues/9181#issuecomment-1996016211 - - OPENAI_API_KEY= + - NVIDIA_API_KEY= # Overwrite any environment variables in the .env file with URLs needed in the network - OPENAI_API_BASE=https://integrate.api.nvidia.com/v1 - OPENAI_BASE_URL=https://integrate.api.nvidia.com/v1 diff --git a/community/event-driven-rag-cve-analysis/cyber_dev_day/config.py b/community/event-driven-rag-cve-analysis/cyber_dev_day/config.py index 30249e4a..7dc3e197 100644 --- a/community/event-driven-rag-cve-analysis/cyber_dev_day/config.py +++ b/community/event-driven-rag-cve-analysis/cyber_dev_day/config.py @@ -55,16 +55,6 @@ class NVFoundationLLMModelConfig(BaseModel): temperature: float = 0.0 -class OpenAIServiceConfig(BaseModel): - type: typing.Literal["openai"] = "openai" - - -class OpenAIMModelConfig(BaseModel): - service: OpenAIServiceConfig - - model_name: str - - class NIMServiceConfig(BaseModel): type: typing.Literal["NIM"] = "NIM" @@ -73,13 +63,11 @@ class NIMModelConfig(BaseModel): service: NIMServiceConfig model_name: str - base_url: str temperature: float = 0.0 top_p: float = 1 LLMModelConfig = typing.Annotated[typing.Annotated[NeMoLLMModelConfig, Tag("nemo")] - | typing.Annotated[OpenAIMModelConfig, Tag("openai")] | typing.Annotated[NVFoundationLLMModelConfig, Tag("nvfoundation")] | typing.Annotated[NIMModelConfig, Tag("NIM")], Discriminator(_llm_discriminator)] diff --git a/community/event-driven-rag-cve-analysis/cyber_dev_day/llm_service.py b/community/event-driven-rag-cve-analysis/cyber_dev_day/llm_service.py index 7777d7ea..34b888bd 100644 --- a/community/event-driven-rag-cve-analysis/cyber_dev_day/llm_service.py +++ b/community/event-driven-rag-cve-analysis/cyber_dev_day/llm_service.py @@ -153,7 +153,7 @@ def create(service_type: str, *service_args, **service_kwargs) -> "LLMService": pass @staticmethod - def create(service_type: str | typing.Literal["nemo"] | typing.Literal["openai"], *service_args, **service_kwargs): + def create(service_type: str | typing.Literal["nemo"] | typing.Literal["nim"], *service_args, **service_kwargs): """ Returns a service for interacting with LLM models. diff --git a/community/event-driven-rag-cve-analysis/cyber_dev_day/nim_llm_service.py b/community/event-driven-rag-cve-analysis/cyber_dev_day/nim_llm_service.py index 621befbe..cb42ede5 100644 --- a/community/event-driven-rag-cve-analysis/cyber_dev_day/nim_llm_service.py +++ b/community/event-driven-rag-cve-analysis/cyber_dev_day/nim_llm_service.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
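For reference, the updated `NIMModelConfig` above drops `base_url` and keeps only `model_name`, `temperature` and `top_p` alongside the `NIM` service tag. Below is a minimal validation sketch, assuming pydantic v2 and that `_llm_discriminator` (not shown in this diff) resolves the tag from the nested service type:

```python
from pydantic import TypeAdapter

from cyber_dev_day.config import LLMModelConfig

# Hypothetical engine-config fragment for a NIM-backed model.
model_cfg = TypeAdapter(LLMModelConfig).validate_python({
    "service": {"type": "NIM"},
    "model_name": "meta/llama3-70b-instruct",
    "temperature": 0.1,
    "top_p": 1,
})
print(type(model_cfg).__name__)  # expected: NIMModelConfig
```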
+ import asyncio import copy import logging @@ -19,84 +20,202 @@ import typing from contextlib import contextmanager from textwrap import dedent -import re import appdirs from cyber_dev_day.llm_service import LLMClient from cyber_dev_day.llm_service import LLMService -from cyber_dev_day.openai_chat_service import OpenAIChatService, OpenAIChatClient logger = logging.getLogger(__name__) IMPORT_EXCEPTION = None -IMPORT_ERROR_MESSAGE = ("OpenAIChatService & OpenAIChatClient require the openai package to be installed. " -                        "Install it by running the following command:\n" -                        "`conda env update --solver=libmamba -n morpheus " -                        "--file conda/environments/dev_cuda-121_arch-x86_64.yaml --prune`") +IMPORT_ERROR_MESSAGE = ( +    "ChatNVIDIAClient requires the langchain-nvidia-ai-endpoints package to be installed. Install it by running: pip install --upgrade --quiet langchain-nvidia-ai-endpoints") try: -    import openai -    import openai.types.chat -    import openai.types.chat.chat_completion +    from langchain_nvidia_ai_endpoints import ChatNVIDIA except ImportError as import_exc: IMPORT_EXCEPTION = import_exc -class NIMChatClient(OpenAIChatClient): +class ChatNVIDIAClient(LLMClient): """ -    Client for interacting with a specific NVIDIA Inference Microservice chat model. This class should be constructed with the -    `NIMLLMService.get_client` method. - -    Parameters -    ---------- -    model_name : str -        The name of the model to interact with. - -    base_url: str -        The URI at which the NIM can be reached. - -    set_assistant: bool, optional default=False -        When `True`, a second input field named `assistant` will be used to proide additional context to the model. - -    max_retries: int, optional default=10 -        The maximum number of retries to attempt when making a request to the OpenAI API. - -    model_kwargs : dict[str, typing.Any] -        Additional keyword arguments to pass to the model when generating text. +    Client for interacting with ChatNVIDIA models through the langchain-nvidia-ai-endpoints package. This class should be constructed with the `NIMLLMService.get_client` method. """ -    _prompt_key: str = "prompt" -    _assistant_key: str = "assistant" -     def __init__(self, -                 parent: "NIMChatService", +                 parent: "NIMLLMService", *, model_name: str, -                 base_url: str, set_assistant: bool = False, max_retries: int = 10, +                 temperature: float = 0.1, +                 top_p: float = 0.0, +                 api_key=os.getenv('NVIDIA_API_KEY'), **model_kwargs) -> None: +        """ +        Initialize the ChatNVIDIAClient. + +        Parameters +        ---------- +        parent : NIMLLMService +            The service instance creating this client. +        model_name : str +            The name of the model to interact with. +        set_assistant : bool, optional +            Flag indicating if the assistant role should be set, by default False. +        max_retries : int, optional +            Maximum number of retries for API requests, by default 10. +        temperature : float, optional +            Sampling temperature passed to the model, by default 0.1. +        top_p : float, optional +            Nucleus sampling value passed to the model, by default 0.0. +        api_key : str, optional +            API key for authenticating with the NVIDIA service, by default read from the 'NVIDIA_API_KEY' environment variable. +        model_kwargs : dict +            Additional model-specific keyword arguments. 
+        """ if IMPORT_EXCEPTION is not None: raise ImportError(IMPORT_ERROR_MESSAGE) from IMPORT_EXCEPTION -        super().__init__( -            parent=parent, -            model_name=model_name, -            set_assistant=set_assistant, -            max_retries=max_retries, -            **model_kwargs -        ) - -        self._base_url = base_url - -        # Create the client objects for both sync and async -        self._client = openai.OpenAI(base_url = self._base_url, max_retries=max_retries) -        self._client_async = openai.AsyncOpenAI(base_url = self._base_url, max_retries=max_retries) - - - -class NIMLLMService(OpenAIChatService): +        super().__init__() + +        self._model_name = model_name +        self._set_assistant = set_assistant +        self._prompt_key = "prompt" +        self._assistant_key = "assistant" +        self._model_kwargs = copy.deepcopy(model_kwargs) + +        print(f"Initializing chat client with temperature {temperature}") + +        self._client = ChatNVIDIA(model=self._model_name, nvidia_api_key=api_key, temperature=temperature, top_p=top_p) + +    async def _generate_async(self, prompt: str, assistant: str = None) -> str: +        """ +        Generate async call to NIM using ChatNVIDIA client. + +        Parameters +        ---------- +        prompt : str +            The prompt to generate a response for. +        assistant : str, optional +            The assistant text to guide the response, by default None. + +        Returns +        ------- +        str +            The generated response content. +        """ +        output = await self._client.ainvoke(prompt) +        return output.content + +    async def generate_batch_async(self, +                                   inputs: dict[str, list], +                                   return_exceptions=False) -> list[str] | list[str | BaseException]: +        """ +        Generate a batch of asynchronous requests to ChatNVIDIA. + +        Parameters +        ---------- +        inputs : dict[str, list] +            Dictionary containing the prompts for batch processing. +        return_exceptions : bool, optional +            If True, exceptions during the async operations will be returned, by default False. + +        Returns +        ------- +        list[str] | list[str | BaseException] +            List of generated responses or exceptions. +        """ +        prompts = inputs[self._prompt_key] + +        coros = [self._generate_async(prompt) for prompt in prompts] + +        return await asyncio.gather(*coros, return_exceptions=return_exceptions) + +    def generate(self, **input_dict) -> str: +        """ +        Issue a request to generate a response based on a given prompt. + +        Parameters +        ---------- +        input_dict : dict +            Input containing prompt data. + +        Returns +        ------- +        str +            The generated response content. +        """ +        return self._client.invoke(input_dict[self._prompt_key]).content + +    async def generate_async(self, **input_dict) -> str: +        """ +        Issue an asynchronous request to generate a response based on a given prompt. + +        Parameters +        ---------- +        input_dict : dict +            Input containing prompt data. + +        Returns +        ------- +        str +            The generated response content. +        """ +        return await self._generate_async(input_dict[self._prompt_key], input_dict.get(self._assistant_key)) + +    def generate_batch(self, inputs: dict[str, list], return_exceptions=False) -> list[str] | list[str | BaseException]: +        """ +        Generate a batch of requests to ChatNVIDIA. + +        Parameters +        ---------- +        inputs : dict[str, list] +            Dictionary containing the prompts for batch processing. +        return_exceptions : bool, optional +            If True, exceptions during the operations will be returned, by default False. + +        Returns +        ------- +        list[str] | list[str | BaseException] +            List of generated responses or exceptions. 
+        """ +        prompts = inputs[self._prompt_key] +        results = [self._generate(prompt) for prompt in prompts] + +        return results + +    def _generate(self, prompt: str) -> str: +        """ +        Issue a request to generate a response based on a given prompt. + +        Parameters +        ---------- +        prompt : str +            The prompt to generate a response for. + +        Returns +        ------- +        str +            The generated response content. +        """ +        output = self._client.invoke(prompt) +        return output.content + +    def get_input_names(self) -> list[str]: +        """ +        Get the names of the required inputs for the model. + +        Returns +        ------- +        list[str] +            List of input names required by the model. +        """ +        input_names = [self._prompt_key] +        if self._set_assistant: +            input_names.append(self._assistant_key) + +        return input_names + + +class NIMLLMService(LLMService): """ A service for interacting with NIM Chat models, this class should be used to create clients. """ @@ -116,7 +235,7 @@ def __init__(self, *, default_model_kwargs: dict = None) -> None: Raises ------ ImportError -            If the `openai` library is not found in the python environment. +            If the `langchain-nvidia-ai-endpoints` library is not found in the python environment. """ if IMPORT_EXCEPTION is not None: raise ImportError(IMPORT_ERROR_MESSAGE) from IMPORT_EXCEPTION @@ -125,14 +244,43 @@ def __init__(self, *, default_model_kwargs: dict = None) -> None: self._default_model_kwargs = default_model_kwargs or {} -         +        self._logger = logging.getLogger(f"{__package__}.{NIMLLMService.__name__}") + +        # Don't propagate up to the default logger. Just log to file +        self._logger.propagate = False + +        log_file = os.path.join(appdirs.user_log_dir(appauthor="NVIDIA", appname="morpheus"), "openai.log") + +        # Add a file handler +        file_handler = logging.FileHandler(log_file) + +        self._logger.addHandler(file_handler) +        self._logger.setLevel(logging.INFO) + +        self._logger.info("NIM LLM Chat Service started.") + +        self._message_count = 0 + +    def _get_message_id(self) -> int: +        """ +        Get a unique message ID for logging purposes. + +        Returns +        ------- +        int +            A unique message ID. +        """ +        self._message_count += 1 + +        return self._message_count + def get_client(self, *, model_name: str, -                   base_url: str, -                   set_assistant: bool = False, max_retries: int = 10, -                   **model_kwargs) -> NIMChatClient: +                   temperature: float = 0.1, +                   top_p: float = 0.0, +                   **model_kwargs) -> ChatNVIDIAClient: """ Returns a client for interacting with a specific model. This method is the preferred way to create a client. @@ -140,26 +288,22 @@ def get_client(self, ---------- model_name : str The name of the model to create a client for. - -        base_url: str -            The URI at which the NIM can be reached. - -        set_assistant: bool, optional default=False -            When `True`, a second input field named `assistant` will be used to proide additional context to the model. - -        max_retries: int, optional default=10 +        max_retries: int, optional The maximum number of retries to attempt when making a request to the OpenAI API. -         model_kwargs : dict[str, typing.Any] Additional keyword arguments to pass to the model when generating text. Arguments specified here will -            overwrite the `default_model_kwargs` set in the service constructor -        """ +            overwrite the `default_model_kwargs` set in the service constructor. +        Returns +        ------- +        ChatNVIDIAClient +            A client instance configured for the specified model. 
+ """ final_model_kwargs = {**self._default_model_kwargs, **model_kwargs} - return NIMChatClient(self, + return ChatNVIDIAClient(self, model_name=model_name, - base_url=base_url, - set_assistant=set_assistant, max_retries=max_retries, + temperature=temperature, + top_p=top_p, **final_model_kwargs) \ No newline at end of file diff --git a/community/event-driven-rag-cve-analysis/cyber_dev_day/openai_chat_service.py b/community/event-driven-rag-cve-analysis/cyber_dev_day/openai_chat_service.py deleted file mode 100644 index 376a6746..00000000 --- a/community/event-driven-rag-cve-analysis/cyber_dev_day/openai_chat_service.py +++ /dev/null @@ -1,397 +0,0 @@ -# Copyright (c) 2023-2024, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import asyncio -import copy -import logging -import os -import time -import typing -from contextlib import contextmanager -from textwrap import dedent - -import appdirs - -from cyber_dev_day.llm_service import LLMClient -from cyber_dev_day.llm_service import LLMService - -logger = logging.getLogger(__name__) - -IMPORT_EXCEPTION = None -IMPORT_ERROR_MESSAGE = ("OpenAIChatService & OpenAIChatClient require the openai package to be installed. " - "Install it by running the following command:\n" - "`conda env update --solver=libmamba -n morpheus " - "--file conda/environments/dev_cuda-121_arch-x86_64.yaml --prune`") - -try: - import openai - import openai.types.chat - import openai.types.chat.chat_completion -except ImportError as import_exc: - IMPORT_EXCEPTION = import_exc - - -class _ApiLogger: - """ - Simple class that allows passing back and forth the inputs and outputs of an API call via a context manager. - """ - - log_template: typing.ClassVar[str] = dedent(""" - ============= MESSAGE %d START ============== - --- Input --- - %s - --- Output --- (%f ms) - %s - ============= MESSAGE %d END ============== - """).strip("\n") - - def __init__(self, *, message_id: int, inputs: typing.Any) -> None: - - self.message_id = message_id - self.inputs = inputs - self.outputs = None - - def set_output(self, output: typing.Any) -> None: - self.outputs = output - - -class OpenAIChatClient(LLMClient): - """ - Client for interacting with a specific OpenAI chat model. This class should be constructed with the - `OpenAIChatService.get_client` method. - - Parameters - ---------- - model_name : str - The name of the model to interact with. - - set_assistant: bool, optional default=False - When `True`, a second input field named `assistant` will be used to proide additional context to the model. - - max_retries: int, optional default=10 - The maximum number of retries to attempt when making a request to the OpenAI API. - - model_kwargs : dict[str, typing.Any] - Additional keyword arguments to pass to the model when generating text. 
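With `openai_chat_service.py` removed, the rewritten client above is reached through `NIMLLMService`. A minimal usage sketch, assuming `NVIDIA_API_KEY` is exported; the model name is only an example:

```python
import asyncio

from cyber_dev_day.nim_llm_service import NIMLLMService

service = NIMLLMService()
client = service.get_client(model_name="meta/llama3-70b-instruct", temperature=0.1, top_p=1.0)

# Single synchronous request; generate() returns the response text.
print(client.generate(prompt="How can one determine if a CVE is exploitable in a given environment?"))

# A batch of prompts is fanned out concurrently via asyncio.gather.
answers = asyncio.run(client.generate_batch_async({"prompt": ["What is CVE-2023-47248?", "What is an SBOM?"]}))
print(answers)
```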
- """ - - _prompt_key: str = "prompt" - _assistant_key: str = "assistant" - - def __init__(self, - parent: "OpenAIChatService", - *, - model_name: str, - set_assistant: bool = False, - max_retries: int = 10, - **model_kwargs) -> None: - if IMPORT_EXCEPTION is not None: - raise ImportError(IMPORT_ERROR_MESSAGE) from IMPORT_EXCEPTION - - super().__init__() - - assert parent is not None, "Parent service cannot be None." - - self._parent = parent - - self._model_name = model_name - self._set_assistant = set_assistant - self._prompt_key = "prompt" - self._assistant_key = "assistant" - - # Preserve original configuration. - self._model_kwargs = copy.deepcopy(model_kwargs) - - # Create the client objects for both sync and async - self._client = openai.OpenAI(max_retries=max_retries) - self._client_async = openai.AsyncOpenAI(max_retries=max_retries) - - def get_input_names(self) -> list[str]: - input_names = [self._prompt_key] - if self._set_assistant: - input_names.append(self._assistant_key) - - return input_names - - @contextmanager - def _api_logger(self, inputs: typing.Any): - - message_id = self._parent._get_message_id() - start_time = time.time() - - api_logger = _ApiLogger(message_id=message_id, inputs=inputs) - - yield api_logger - - end_time = time.time() - duration_ms = (end_time - start_time) * 1000.0 - - self._parent._logger.info(_ApiLogger.log_template, - message_id, - api_logger.inputs, - duration_ms, - api_logger.outputs, - message_id) - - def _create_messages(self, - prompt: str, - assistant: str = None) -> list["openai.types.chat.ChatCompletionMessageParam"]: - messages: list[openai.types.chat.ChatCompletionMessageParam] = [{"role": "user", "content": prompt}] - - if (self._set_assistant and assistant is not None): - messages.append({"role": "assistant", "content": assistant}) - - return messages - - def _extract_completion(self, completion: "openai.types.chat.chat_completion.ChatCompletion") -> str: - choices = completion.choices - if len(choices) == 0: - raise ValueError("No choices were returned from the model.") - - content = choices[0].message.content - if content is None: - raise ValueError("No content was returned from the model.") - - return content - - @typing.overload - def _generate(self, - prompt: str, - assistant: str = None, - return_exceptions: typing.Literal[True] = True) -> str | BaseException: - ... - - @typing.overload - def _generate(self, prompt: str, assistant: str = None, return_exceptions: typing.Literal[False] = False) -> str: - ... - - def _generate(self, prompt: str, assistant: str = None, return_exceptions: bool = False): - - try: - messages = self._create_messages(prompt, assistant) - - output: openai.types.chat.chat_completion.ChatCompletion = self._client.chat.completions.create( - model=self._model_name, messages=messages, **self._model_kwargs) - - return self._extract_completion(output) - except BaseException as e: - - if return_exceptions: - return e - - raise - - def generate(self, **input_dict) -> str: - """ - Issue a request to generate a response based on a given prompt. - - Parameters - ---------- - input_dict : dict - Input containing prompt data. 
- """ - return self._generate(input_dict[self._prompt_key], - input_dict.get(self._assistant_key), - return_exceptions=False) - - async def _generate_async(self, prompt: str, assistant: str = None) -> str: - - messages = self._create_messages(prompt, assistant) - - with self._api_logger(inputs=messages) as msg_logger: - - try: - output = await self._client_async.chat.completions.create(model=self._model_name, - messages=messages, - **self._model_kwargs) - except Exception as exc: - self._parent._logger.error("Error generating completion: %s", exc) - raise - - msg_logger.set_output(output) - - return self._extract_completion(output) - - async def generate_async(self, **input_dict) -> str: - """ - Issue an asynchronous request to generate a response based on a given prompt. - - Parameters - ---------- - input_dict : dict - Input containing prompt data. - """ - return await self._generate_async(input_dict[self._prompt_key], input_dict.get(self._assistant_key)) - - @typing.overload - def generate_batch(self, - inputs: dict[str, list], - return_exceptions: typing.Literal[True] = True, **kwargs) -> list[str | BaseException]: - ... - - @typing.overload - def generate_batch(self, inputs: dict[str, list], return_exceptions: typing.Literal[False] = False, **kwargs) -> list[str]: - ... - - def generate_batch(self, inputs: dict[str, list], return_exceptions=False, **kwargs) -> list[str] | list[str | BaseException]: - """ - Issue a request to generate a list of responses based on a list of prompts. - - Parameters - ---------- - inputs : dict - Inputs containing prompt data. - return_exceptions : bool - Whether to return exceptions in the output list or raise them immediately. - """ - prompts = inputs[self._prompt_key] - assistants = None - if (self._set_assistant): - assistants = inputs[self._assistant_key] - if len(prompts) != len(assistants): - raise ValueError("The number of prompts and assistants must be equal.") - - results = [] - for (i, prompt) in enumerate(prompts): - assistant = assistants[i] if assistants is not None else None - if (return_exceptions): - results.append(self._generate(prompt, assistant, return_exceptions=True, **kwargs)) - else: - results.append(self._generate(prompt, assistant, return_exceptions=False, **kwargs)) - - return results - - @typing.overload - async def generate_batch_async(self, - inputs: dict[str, list], - return_exceptions: typing.Literal[True] = True, **kwargs) -> list[str | BaseException]: - ... - - @typing.overload - async def generate_batch_async(self, - inputs: dict[str, list], - return_exceptions: typing.Literal[False] = False, **kwargs) -> list[str]: - ... - - async def generate_batch_async(self, - inputs: dict[str, list], - return_exceptions=False, **kwargs) -> list[str] | list[str | BaseException]: - """ - Issue an asynchronous request to generate a list of responses based on a list of prompts. - - Parameters - ---------- - inputs : dict - Inputs containing prompt data. - return_exceptions : bool - Whether to return exceptions in the output list or raise them immediately. 
- """ - prompts = inputs[self._prompt_key] - assistants = None - if (self._set_assistant): - assistants = inputs[self._assistant_key] - if len(prompts) != len(assistants): - raise ValueError("The number of prompts and assistants must be equal.") - - coros = [] - for (i, prompt) in enumerate(prompts): - assistant = assistants[i] if assistants is not None else None - coros.append(self._generate_async(prompt, assistant, **kwargs)) - - return await asyncio.gather(*coros, return_exceptions=return_exceptions, **kwargs) - - -class OpenAIChatService(LLMService): - """ - A service for interacting with OpenAI Chat models, this class should be used to create clients. - """ - - def __init__(self, *, default_model_kwargs: dict = None) -> None: - """ - Creates a service for interacting with OpenAI Chat models, this class should be used to create clients. - - Parameters - ---------- - default_model_kwargs : dict, optional - Default arguments to use when creating a client via the `get_client` function. Any argument specified here - will automatically be used when calling `get_client`. Arguments specified in the `get_client` function will - overwrite default values specified here. This is useful to set model arguments before creating multiple - clients. By default None - - Raises - ------ - ImportError - If the `openai` library is not found in the python environment. - """ - if IMPORT_EXCEPTION is not None: - raise ImportError(IMPORT_ERROR_MESSAGE) from IMPORT_EXCEPTION - - super().__init__() - - self._default_model_kwargs = default_model_kwargs or {} - - self._logger = logging.getLogger(f"{__package__}.{OpenAIChatService.__name__}") - - # Dont propagate up to the default logger. Just log to file - self._logger.propagate = False - - log_file = os.path.join(appdirs.user_log_dir(appauthor="NVIDIA", appname="morpheus"), "openai.log") - - # Add a file handler - file_handler = logging.FileHandler(log_file) - - self._logger.addHandler(file_handler) - self._logger.setLevel(logging.INFO) - - self._logger.info("OpenAI Chat Service started.") - - self._message_count = 0 - - def _get_message_id(self): - - self._message_count += 1 - - return self._message_count - - def get_client(self, - *, - model_name: str, - set_assistant: bool = False, - max_retries: int = 10, - **model_kwargs) -> OpenAIChatClient: - """ - Returns a client for interacting with a specific model. This method is the preferred way to create a client. - - Parameters - ---------- - model_name : str - The name of the model to create a client for. - - set_assistant: bool, optional default=False - When `True`, a second input field named `assistant` will be used to proide additional context to the model. - - max_retries: int, optional default=10 - The maximum number of retries to attempt when making a request to the OpenAI API. - - model_kwargs : dict[str, typing.Any] - Additional keyword arguments to pass to the model when generating text. 
Arguments specified here will - overwrite the `default_model_kwargs` set in the service constructor - """ - - final_model_kwargs = {**self._default_model_kwargs, **model_kwargs} - - return OpenAIChatClient(self, - model_name=model_name, - set_assistant=set_assistant, - max_retries=max_retries, - **final_model_kwargs) diff --git a/community/event-driven-rag-cve-analysis/cyber_dev_day/pipeline_utils.py b/community/event-driven-rag-cve-analysis/cyber_dev_day/pipeline_utils.py index 660d4e61..54733c1c 100644 --- a/community/event-driven-rag-cve-analysis/cyber_dev_day/pipeline_utils.py +++ b/community/event-driven-rag-cve-analysis/cyber_dev_day/pipeline_utils.py @@ -18,9 +18,12 @@ from langchain.agents import Tool from langchain.agents import initialize_agent from langchain.agents.agent import AgentExecutor -from langchain.chains import RetrievalQA from langchain.embeddings.huggingface import HuggingFaceEmbeddings from langchain.vectorstores.faiss import FAISS +from langchain import hub +from langchain.chains import create_retrieval_chain +from langchain.chains.combine_documents import create_stuff_documents_chain +from langchain.tools.retriever import create_retriever_tool from morpheus.llm import LLMEngine from morpheus.llm.nodes.extracter_node import ExtracterNode @@ -71,146 +74,13 @@ def build_agent_executor(config: EngineAgentConfig, handle_parsing_errors=False) code_vector_db = FAISS.load_local(folder_path=config.code_repo.faiss_dir, embeddings=embeddings, allow_dangerous_deserialization=True) - code_qa_tool = RetrievalQA.from_chain_type(llm=langchain_llm, - chain_type="stuff", - retriever=code_vector_db.as_retriever()) - tools.append( - Tool(name="Docker Container Code QA System", - func=code_qa_tool.run, - description=("useful for when you need to review code to check for an import or function usage in " - "the Docker container. Input should be a question or the actual code. "))) - - sys_prompt = ("You are a very powerful assistant who helps investigate Docker containers " - " given a checklist of investigation items. Your role is to walk through a provided checklist and answer each item in the checklist. " - " Do not investigate additional information per checklist item, just answer the checklist. " - " Information about the Docker container under investigation is stored in vector databases available to you via tools. ") - - if handle_parsing_errors: - agent_executor = initialize_agent(tools, - langchain_llm, - agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, - verbose=config.verbose, - handle_parsing_errors="Check your output. Make sure you're using the right Action/Action input syntax.") - else: - agent_executor = initialize_agent(tools, - langchain_llm, - agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, - verbose=config.verbose) - - agent_executor.agent.llm_chain.prompt.template = ( - sys_prompt + ' ' + agent_executor.agent.llm_chain.prompt.template.replace( - "Answer the following questions as best you can.", - ("If the input is not a question, formulate it into a question first. " - "Include intermediate thought in the final answer.")).replace( - "Use the following format:", - ("Use the following format (start each response with one of the following prefixes): " - "[Question, Thought, Action, Action Input, Final Answer]). " - "If you are making an action, wait for a response to the action input before making an observation. Every response must contain at least one action (and thoughts and observations if you have them), but you cannot have both a final answer and an action in a response. 
Action input must only contain the exact input, do not provide any text following that in your response. Always end your response with either an action, or a final answer."))) - - return agent_executor - - -def build_cve_llm_engine(config: EngineConfig, handle_parsing_errors=True) -> LLMEngine: - engine = LLMEngine() - - engine.add_node("extracter", node=ExtracterNode()) - - engine.add_node("checklist", inputs=["/extracter"], node=CVEChecklistNode(config=config.checklist)) - - engine.add_node("agent", - inputs=[("/checklist")], - node=LangChainAgentNode(agent_executor=build_agent_executor(config=config.agent, - handle_parsing_errors=handle_parsing_errors))) - - engine.add_task_handler( - inputs=[("/checklist", "checklist"), ("/agent", "response")], - handler=SimpleTaskHandler(output_columns=["checklist", "response"]), - ) - - return engine - - -# Copyright (c) 2024, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from langchain.agents import AgentType -from langchain.agents import Tool -from langchain.agents import initialize_agent -from langchain.agents.agent import AgentExecutor -from langchain.chains import RetrievalQA -from langchain.embeddings.huggingface import HuggingFaceEmbeddings -from langchain.vectorstores.faiss import FAISS - -from morpheus.llm import LLMEngine -from morpheus.llm.nodes.extracter_node import ExtracterNode -from morpheus.llm.nodes.langchain_agent_node import LangChainAgentNode -from cyber_dev_day.llm_service import LLMService -from cyber_dev_day.langchain_llm_client_wrapper import LangchainLLMClientWrapper -from morpheus.llm.task_handlers.simple_task_handler import SimpleTaskHandler - -from .checklist_node import CVEChecklistNode -from .config import EngineAgentConfig -from .config import EngineConfig -from .tools import SBOMChecker - -logger = logging.getLogger(__name__) - - -def build_agent_executor(config: EngineAgentConfig, handle_parsing_errors=False) -> AgentExecutor: - llm_service = LLMService.create(config.model.service.type, **config.model.service.model_dump(exclude={"type"})) - - llm_client = llm_service.get_client(**config.model.model_dump(exclude={"service"})) - - # Wrap the Morpheus client in a LangChain compatible wrapper - langchain_llm = LangchainLLMClientWrapper(client=llm_client) - - # tools = load_tools(["serpapi", "llm-math"], llm=llm) - tools: list[Tool] = [] - - if (config.sbom.data_file is not None): - # Load the SBOM - sbom_checker = SBOMChecker.from_csv(config.sbom.data_file) - - tools.append( - Tool(name="SBOM Package Checker", - func=sbom_checker.sbom_checker, - description=("useful for when you need to check the Docker container's software bill of " - "materials (SBOM) to get whether or not a given library is in the container. " - "Input should be the name of the library or software, and no text following it until a response is returned. 
" - "If the package is " - "present a version number is returned, otherwise False is returned if the " - "package is not present."))) - - if (config.code_repo.faiss_dir is not None): - embeddings = HuggingFaceEmbeddings(model_name=config.code_repo.embedding_model_name, - model_kwargs={'device': 'cuda'}, - encode_kwargs={'normalize_embeddings': False}) - - # load code vector DB - code_vector_db = FAISS.load_local(folder_path=config.code_repo.faiss_dir, - embeddings=embeddings, - allow_dangerous_deserialization=True) - code_qa_tool = RetrievalQA.from_chain_type(llm=langchain_llm, - chain_type="stuff", - retriever=code_vector_db.as_retriever()) - tools.append( - Tool(name="Docker Container Code QA System", - func=code_qa_tool.run, - description=("useful for when you need to review code to check for an import or function usage in " - "the Docker container. Input should be a question or the actual code. "))) + tools.append(create_retriever_tool( + code_vector_db.as_retriever(), + "Docker Container Code QA System", + ("useful for when you need to review code to check for an import or function usage in " + "the Docker container. Input should be a question or the actual code. ") + )) sys_prompt = ("You are a very powerful assistant who helps investigate Docker containers " " given a checklist of investigation items. Your role is to walk through a provided checklist and answer each item in the checklist. " @@ -222,7 +92,7 @@ def build_agent_executor(config: EngineAgentConfig, handle_parsing_errors=False) langchain_llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=config.verbose, - handle_parsing_errors="Check your output. Make sure you're using the right Action/Action input syntax.") + handle_parsing_errors="Check your output. Each thought must end with a 'Final Anser', or 'Action'/'Action Input'. For action inputs, adhere to the tool description's instructions exactly. Thoughts cannot be empty strings.") else: agent_executor = initialize_agent(tools, langchain_llm, @@ -230,14 +100,14 @@ def build_agent_executor(config: EngineAgentConfig, handle_parsing_errors=False) verbose=config.verbose) agent_executor.agent.llm_chain.prompt.template = ( - sys_prompt + ' ' + agent_executor.agent.llm_chain.prompt.template.replace( - "Answer the following questions as best you can.", - ("If the input is not a question, formulate it into a question first. " - "Include intermediate thought in the final answer.")).replace( - "Use the following format:", - ("Use the following format (start each response with one of the following prefixes): " - "[Question, Thought, Action, Action Input, Final Answer]). " - "If you are making an action, wait for a response to the action input before making an observation. Every response must contain at least one action (and thoughts and observations if you have them), but you cannot have both a final answer and an action in a response. Action input must only contain the exact input, do not provide any text following that in your response. Always end your response with either an action, or a final answer."))) + sys_prompt + ' ' + agent_executor.agent.llm_chain.prompt.template.replace( + "Answer the following questions as best you can.", + ("If the input is not a question, formulate it into a question first. " + "Include intermediate thought in the final answer.")).replace( + "Use the following format:", + ("Use the following format (start each response with one of the following prefixes): " + "[Question, Thought, Action, Action Input, Final Answer]). 
" + "If you are making an action, wait for a response to the action input before making an observation. Every response must contain at least one action (and thoughts and observations if you have them), but you cannot have both a final answer and an action in a response. Action input must only contain the exact input, do not provide any text following that in your response. Always end your response with either an action, or a final answer."))) return agent_executor diff --git a/community/event-driven-rag-cve-analysis/docker-compose.yml b/community/event-driven-rag-cve-analysis/docker-compose.yml index 1376d260..cba6e620 100755 --- a/community/event-driven-rag-cve-analysis/docker-compose.yml +++ b/community/event-driven-rag-cve-analysis/docker-compose.yml @@ -48,7 +48,7 @@ services: environment: - TERM=${TERM:-} # Workaround until this is working: https://github.com/docker/compose/issues/9181#issuecomment-1996016211 - - OPENAI_API_KEY= + - NVIDIA_API_KEY= # Overwrite any environment variables in the .env file with URLs needed in the network - OPENAI_API_BASE=https://integrate.api.nvidia.com/v1 - OPENAI_BASE_URL=https://integrate.api.nvidia.com/v1 diff --git a/community/event-driven-rag-cve-analysis/notebooks/cyber-dev-day.ipynb b/community/event-driven-rag-cve-analysis/notebooks/cyber-dev-day.ipynb index 28e07a16..608ba124 100644 --- a/community/event-driven-rag-cve-analysis/notebooks/cyber-dev-day.ipynb +++ b/community/event-driven-rag-cve-analysis/notebooks/cyber-dev-day.ipynb @@ -135,7 +135,7 @@ "outputs": [], "source": [ "# Ensure that the current environment is set up with API keys\n", - "required_env_vars = [\"MORPHEUS_ROOT\", \"OPENAI_API_KEY\", \"OPENAI_BASE_URL\"]\n", + "required_env_vars = [\"MORPHEUS_ROOT\", \"NVIDIA_API_KEY\"]\n", "\n", "if (not all([var in os.environ for var in required_env_vars])):\n", "\n", @@ -265,29 +265,6 @@ "NVIDIA NIM microservices are OpenAI API compliant to maximize usability, so we will be using the openai with package as a wrapped to make API calls.\n" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "67c8c1d6", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Create the connection object. The API key and organization ID are read from the environment variables NGC_API_KEY and\n", - "# NGC_ORG_ID respectively\n", - "api_key = os.getenv(\"OPENAI_API_KEY\")\n", - "base_url = os.getenv(\"OPENAI_BASE_URL\")\n", - "\n", - "llm_client = OpenAI(\n", - " base_url = base_url,\n", - " api_key = api_key\n", - ")\n", - "\n", - "\n", - "print(f\"Connected to LLM hosted at: {base_url}\")" - ] - }, { "cell_type": "markdown", "id": "4639ba36-02ac-4a94-9815-182bbda12c38", @@ -320,19 +297,11 @@ }, "outputs": [], "source": [ - "completion = llm_client.chat.completions.create(\n", - " model=\"mistralai/mixtral-8x22b-instruct-v0.1\",\n", - " messages=[{\"role\":\"user\",\"content\":\"How can one determine if a CVE is vulnerable in a specific environment?\"}], #Prompt goes here\n", - " temperature=0.5,\n", - " top_p=1,\n", - " max_tokens=1024,\n", - " stream=True\n", - " \n", - ")\n", - "\n", - "for chunk in completion:\n", - " if chunk.choices[0].delta.content is not None:\n", - " print(chunk.choices[0].delta.content, end=\"\")" + "# Create the connection object. 
The API key and organization ID are read from the environment variables NGC_API_KEY and\n", + "# NGC_ORG_ID respectively\n", + "from langchain_nvidia_ai_endpoints import ChatNVIDIA\n", + "completion = ChatNVIDIA(model=\"mistralai/mixtral-8x7b-instruct-v0.1\").invoke(\"How can one determine if a CVE is vulnerable in a specific environment?\")\n", + "print(completion.content) #API Key is read from environment variable" ] }, { @@ -357,19 +326,8 @@ "outputs": [], "source": [ "# # UNCOMMENT to try different models\n", - "# completion = llm_client.chat.completions.create(\n", - "# model=\"\",\n", - "# messages=[{\"role\":\"user\",\"content\":\"How can one determine if a CVE is vulnerable in a specific environment?\"}], #Prompt goes here\n", - "# temperature=0.5,\n", - "# top_p=1,\n", - "# max_tokens=1024,\n", - "# stream=True\n", - " \n", - "# )\n", - "\n", - "# for chunk in completion:\n", - "# if chunk.choices[0].delta.content is not None:\n", - "# print(chunk.choices[0].delta.content, end=\"\")" + "# completion = ChatNVIDIA(model=\"\").invoke(\"How can one determine if a CVE is vulnerable in a specific environment?\")\n", + "# print(completion.content) " ] }, { @@ -408,19 +366,10 @@ "# # UNCOMMENT to try different parameters\n", "# # Analyze output of the model for different value of temperature and top_k\n", "# for temp in [0.0, 0.5, 0.7, 0.9]:\n", - "# completion = llm_client.chat.completions.create(\n", - "# model=\"mistralai/mixtral-8x22b-instruct-v0.1\",\n", - "# messages=[{\"role\":\"user\",\"content\":\"How can one determine if a CVE is vulnerable in a specific environment?\"}], #Prompt goes here\n", - "# temperature=temp,\n", - "# top_p=1,\n", - "# max_tokens=1024,\n", - "# stream=True\n", - "\n", - "# )\n", - "\n", - "# for chunk in completion:\n", - "# if chunk.choices[0].delta.content is not None:\n", - "# print(chunk.choices[0].delta.content, end=\"\")\n", + "# completion = ChatNVIDIA(\n", + "# model=\"mistralai/mixtral-8x7b-instruct-v0.1\",\n", + "# temperature=temp).invoke(\"How can one determine if a CVE is vulnerable in a specific environment?\")\n", + "# print(completion.content)\n", " \n", "# print(\"\\n-----\\n\")" ] @@ -466,19 +415,9 @@ "formatted_prompt = \"{persona} {query}\".format(\n", " persona=security_expert_persona, query=\"How can one determine if a CVE is vulnerable in a specific environment?\")\n", "\n", - "completion = llm_client.chat.completions.create(\n", - " model=\"mistralai/mixtral-8x22b-instruct-v0.1\",\n", - " messages=[{\"role\":\"user\",\"content\":formatted_prompt}], \n", - " temperature=0.5,\n", - " top_p=1,\n", - " max_tokens=1024,\n", - " stream=True\n", - " \n", - ")\n", + "completion = ChatNVIDIA(model=\"mistralai/mixtral-8x7b-instruct-v0.1\").invoke(formatted_prompt)\n", "\n", - "for chunk in completion:\n", - " if chunk.choices[0].delta.content is not None:\n", - " print(chunk.choices[0].delta.content, end=\"\")" + "print(completion.content)" ] }, { @@ -508,19 +447,9 @@ "# formatted_prompt = \"{persona} {query}\".format(\n", "# persona=persona, query=\"How can one determine if a CVE is vulnerable in a specific environment?\")\n", "\n", - "# completion = llm_client.chat.completions.create(\n", - "# model=\"mistralai/mixtral-8x22b-instruct-v0.1\",\n", - "# messages=[{\"role\":\"user\",\"content\":formatted_prompt}], \n", - "# temperature=0.5,\n", - "# top_p=1,\n", - "# max_tokens=1024,\n", - "# stream=True\n", - " \n", - "# )\n", + "# completion = ChatNVIDIA(model=\"mistralai/mixtral-8x7b-instruct-v0.1\").invoke(formatted_prompt)\n", "\n", - "# 
for chunk in completion:\n", - "# if chunk.choices[0].delta.content is not None:\n", - "# print(chunk.choices[0].delta.content, end=\"\")" + "# print(completion.content)" ] }, { @@ -558,19 +487,9 @@ " persona=\"You are helpful cybersecurity expert with an IQ of 140.\",\n", " query=\"How can I determine if my specific environment is affected by CVE-2023-47248?\")\n", "\n", - "completion = llm_client.chat.completions.create(\n", - " model=\"mistralai/mixtral-8x22b-instruct-v0.1\",\n", - " messages=[{\"role\":\"user\",\"content\":formatted_prompt}], \n", - " temperature=0.5,\n", - " top_p=1,\n", - " max_tokens=1024,\n", - " stream=True\n", - " \n", - ")\n", + "completion = ChatNVIDIA(model=\"mistralai/mixtral-8x7b-instruct-v0.1\").invoke(formatted_prompt)\n", "\n", - "for chunk in completion:\n", - " if chunk.choices[0].delta.content is not None:\n", - " print(chunk.choices[0].delta.content, end=\"\")" + "print(completion.content)" ] }, { @@ -659,17 +578,9 @@ "\n", "formatted_prompt = prompt_template.format(**PYARROW_CVE_INTEL)\n", "\n", - "completion = llm_client.chat.completions.create(\n", - " model=\"mistralai/mixtral-8x22b-instruct-v0.1\",\n", - " messages=[{\"role\":\"user\",\"content\":formatted_prompt}],\n", - " temperature=0.5,\n", - " top_p=1,\n", - " max_tokens=1024,\n", - " stream=False\n", - " \n", - ")\n", + "completion = ChatNVIDIA(model=\"mistralai/mixtral-8x7b-instruct-v0.1\").invoke(formatted_prompt)\n", "\n", - "print(completion.choices[0].message.content)" + "print(completion.content)" ] }, { @@ -718,19 +629,9 @@ "\n", "# for model in models_to_try:\n", "# print(f\"\\n----\\nModel: {model}\")\n", - "# completion = llm_client.chat.completions.create(\n", - "# model=model,\n", - "# messages=[{\"role\":\"user\",\"content\":formatted_prompt}],\n", - "# temperature=0.5,\n", - "# top_p=1,\n", - "# max_tokens=1024,\n", - "# stream=True\n", + "# completion = ChatNVIDIA(model=model).invoke(formatted_prompt)\n", "\n", - "# )\n", - "\n", - "# for chunk in completion:\n", - "# if chunk.choices[0].delta.content is not None:\n", - "# print(chunk.choices[0].delta.content, end=\"\")" + "# print(completion.content)" ] }, { @@ -789,15 +690,7 @@ }, "outputs": [], "source": [ - "is_properly_formatted_list(\n", - " llm_client.chat.completions.create(\n", - " model=\"meta/llama3-70b-instruct\",\n", - " messages=[{\"role\":\"user\",\"content\":formatted_prompt}], #Prompt goes here\n", - " temperature=0.5,\n", - " top_p=1,\n", - " max_tokens=1024,\n", - " stream=False\n", - ").choices[0].message.content)" + "is_properly_formatted_list(ChatNVIDIA(model=\"meta/llama3-70b-instruct\").invoke(formatted_prompt).content)" ] }, { @@ -864,14 +757,7 @@ "\n", "formatted_prompt = zero_shot_template.format(checklist=unparsable_list)\n", "\n", - "model_output = llm_client.chat.completions.create(\n", - " model=\"meta/llama3-70b-instruct\",\n", - " messages=[{\"role\":\"user\",\"content\":formatted_prompt}], \n", - " temperature=0.5,\n", - " top_p=1,\n", - " max_tokens=1024,\n", - " stream=False\n", - ").choices[0].message.content\n", + "model_output = ChatNVIDIA(model=\"meta/llama3-70b-instruct\").invoke(formatted_prompt).content\n", "\n", "print(model_output)" ] @@ -920,14 +806,7 @@ "\n", "formatted_prompt = one_shot_template.format(checklist=unparsable_list)\n", "\n", - "model_output = llm_client.chat.completions.create(\n", - " model=\"meta/llama3-70b-instruct\",\n", - " messages=[{\"role\":\"user\",\"content\":formatted_prompt}],\n", - " temperature=0.5,\n", - " top_p=1,\n", - " max_tokens=1024,\n", - " 
stream=False\n", - ").choices[0].message.content\n", + "model_output = ChatNVIDIA(model=\"meta/llama3-70b-instruct\").invoke(formatted_prompt).content\n", "\n", "print(model_output)" ] @@ -966,21 +845,15 @@ "outputs": [], "source": [ "# # UNCOMMENT to try with an enumerated list\n", - "# enumerated_list = \"\"\"1. Check if the vulnerable package, PyArrow, is installed in the container.\n", - "# 2. If the vulnerable package is installed, check the version of the package. If it is before 14.0.1, the vulnerability is present.\n", - "# 3. Check if the container has any exposed IPC or Parquet readers.\"\"\"\n", - "\n", - "# formatted_prompt = one_shot_template.format(checklist=unparsable_list)\n", - "\n", - "# model_output = llm_client.chat.completions.create(\n", - "# model=\"meta/llama3-70b-instruct\",\n", - "# messages=[{\"role\":\"user\",\"content\":formatted_prompt}], #Prompt goes here\n", - "# temperature=0.5,\n", - "# top_p=1,\n", - "# max_tokens=1024,\n", - "# stream=False\n", - "# ).choices[0].message.content\n", - "# print(model_output)" + "enumerated_list = \"\"\"1. Check if the vulnerable package, PyArrow, is installed in the container.\n", + "2. If the vulnerable package is installed, check the version of the package. If it is before 14.0.1, the vulnerability is present.\n", + "3. Check if the container has any exposed IPC or Parquet readers.\"\"\"\n", + "\n", + "formatted_prompt = one_shot_template.format(checklist=unparsable_list)\n", + "\n", + "model_output = ChatNVIDIA(model=\"meta/llama3-70b-instruct\").invoke(formatted_prompt).content\n", + "\n", + "print(model_output)" ] }, { @@ -1084,14 +957,7 @@ "\n", "formatted_prompt = few_shot_prompt_template.format(**PYARROW_CVE_INTEL)\n", "\n", - "model_output = llm_client.chat.completions.create(\n", - " model=\"meta/llama3-70b-instruct\",\n", - " messages=[{\"role\":\"user\",\"content\":formatted_prompt}],\n", - " temperature=0.5,\n", - " top_p=1,\n", - " max_tokens=1024,\n", - " stream=False\n", - ").choices[0].message.content\n", + "model_output = ChatNVIDIA(model=\"meta/llama3-70b-instruct\").invoke(formatted_prompt).content\n", "\n", "print(model_output)" ] @@ -1136,14 +1002,7 @@ "# \"Check for PyArrow: Verify if your project uses the PyArrow library, which is the affected package. 
\"\n", "# \"If PyArrow is not a dependency in your project, then your code is not vulnerable to this CVE.\"\n", "# )\n", - "# print(llm_client.chat.completions.create(\n", - "# model=\"meta/llama3-70b-instruct\",\n", - "# messages=[{\"role\":\"user\",\"content\":example_checklist_item}], #Prompt goes here\n", - "# temperature=0.5,\n", - "# top_p=1,\n", - "# max_tokens=1024,\n", - "# stream=False\n", - "# ).choices[0].message.content)" + "# print(ChatNVIDIA(model=\"meta/llama3-70b-instruct\").invoke(example_checklist_item).content)" ] }, { @@ -1183,14 +1042,7 @@ "\n", "# updated_prompt = repeat_substring(few_shot_prompt_template, example_one_start_index, example_one_end_index, 14)\n", "\n", - "# model_output = llm_client.chat.completions.create(\n", - "# model=\"meta/llama3-8b-instruct\",\n", - "# messages=[{\"role\":\"user\",\"content\":updated_prompt}], #Prompt goes here\n", - "# temperature=0.5,\n", - "# top_p=1,\n", - "# max_tokens=1024,\n", - "# stream=False\n", - "# ).choices[0].message.content\n", + "# model_output = ChatNVIDIA(model=\"meta/llama3-70b-instruct\").invoke(updated_prompt).content\n", "# print(model_output)" ] }, @@ -1514,6 +1366,7 @@ "execution_count": null, "id": "506b5fb3", "metadata": { + "scrolled": true, "tags": [] }, "outputs": [], @@ -1700,7 +1553,7 @@ " },\n", " \"base_url\": \"https://integrate.api.nvidia.com/v1\",\n", " \"model_name\": \"meta/llama3-70b-instruct\",\n", - " \"temperature\": 0.02\n", + " \"temperature\": 0.1\n", " },\n", " \"sbom\": {\n", " \"data_file\":\n", @@ -1759,7 +1612,7 @@ "cve_details = cve_details_template.format(**PYARROW_CVE_INTEL)\n", "\n", "# Now run the pipeline with a specified CVE description\n", - "await run_cve_pipeline(pipeline_config, engine_config, [cve_details], retry_bad_input=False)" + "await run_cve_pipeline(pipeline_config, engine_config, [cve_details])" ] }, { @@ -2072,7 +1925,7 @@ "from morpheus.utils.http_utils import HTTPMethod\n", "\n", "\n", - "async def run_cve_pipeline_microservice(p_config: Config, e_config: EngineConfig):\n", + "async def run_cve_pipeline_microservice(p_config: Config, e_config: EngineConfig, retry_bad_input = True):\n", "\n", " completion_task = {\"task_type\": \"completion\", \"task_dict\": {\"input_keys\": [\"cve_info\"], }}\n", "\n", @@ -2099,7 +1952,7 @@ " pipe.add_stage(\n", " DeserializeStage(p_config, message_type=ControlMessage, task_type=\"llm_engine\", task_payload=completion_task))\n", "\n", - " pipe.add_stage(LLMEngineStage(p_config, engine=build_cve_llm_engine(e_config)))\n", + " pipe.add_stage(LLMEngineStage(p_config, engine=build_cve_llm_engine(e_config, retry_bad_input)))\n", "\n", " sink = pipe.add_stage(InMemorySinkStage(p_config))\n", "\n", @@ -2128,10 +1981,7 @@ "execution_count": null, "id": "f4f87ac9", "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true - }, + "scrolled": true, "tags": [] }, "outputs": [],