Commit 53b9308

Co-authored-by: Wendong-Fan <[email protected]>
Co-authored-by: Wendong <[email protected]>

1 parent: 26100a9

Showing 7 changed files with 330 additions and 0 deletions.
@@ -0,0 +1,125 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========

import os
from typing import Any, Dict, List, Optional, Union

from openai import OpenAI, Stream

from camel.configs import OPENAI_API_PARAMS
from camel.messages import OpenAIMessage
from camel.models import BaseModelBackend
from camel.types import ChatCompletion, ChatCompletionChunk, ModelType
from camel.utils import (
    BaseTokenCounter,
    OpenAITokenCounter,
    model_api_key_required,
)


class ZhipuAIModel(BaseModelBackend):
    r"""ZhipuAI API in a unified BaseModelBackend interface."""

    def __init__(
        self,
        model_type: ModelType,
        model_config_dict: Dict[str, Any],
        api_key: Optional[str] = None,
        url: Optional[str] = None,
    ) -> None:
        r"""Constructor for ZhipuAI backend.

        Args:
            model_type (ModelType): Model for which a backend is created,
                such as GLM_* series.
            model_config_dict (Dict[str, Any]): A dictionary that will
                be fed into openai.ChatCompletion.create().
            api_key (Optional[str]): The API key for authenticating with the
                ZhipuAI service. (default: :obj:`None`)
            url (Optional[str]): The url to the ZhipuAI service.
                (default: :obj:`None`)
        """
        super().__init__(model_type, model_config_dict)
        self._url = url or os.environ.get("ZHIPUAI_API_BASE_URL")
        self._api_key = api_key or os.environ.get("ZHIPUAI_API_KEY")
        self._client = OpenAI(
            timeout=60,
            max_retries=3,
            api_key=self._api_key,
            base_url=self._url,
        )
        self._token_counter: Optional[BaseTokenCounter] = None

    @model_api_key_required
    def run(
        self,
        messages: List[OpenAIMessage],
    ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
        r"""Runs inference of OpenAI chat completion.

        Args:
            messages (List[OpenAIMessage]): Message list with the chat history
                in OpenAI API format.

        Returns:
            Union[ChatCompletion, Stream[ChatCompletionChunk]]:
                `ChatCompletion` in the non-stream mode, or
                `Stream[ChatCompletionChunk]` in the stream mode.
        """
        # Use the OpenAI client as the interface to call ZhipuAI.
        # Reference: https://open.bigmodel.cn/dev/api#openai_sdk
        response = self._client.chat.completions.create(
            messages=messages,
            model=self.model_type.value,
            **self.model_config_dict,
        )
        return response

    @property
    def token_counter(self) -> BaseTokenCounter:
        r"""Initialize the token counter for the model backend.

        Returns:
            OpenAITokenCounter: The token counter following the model's
                tokenization style.
        """
        if not self._token_counter:
            # Temporary setting: count tokens with the GPT-3.5-Turbo tokenizer.
            self._token_counter = OpenAITokenCounter(ModelType.GPT_3_5_TURBO)
        return self._token_counter

    def check_model_config(self):
        r"""Check whether the model configuration contains any
        unexpected arguments to OpenAI API.

        Raises:
            ValueError: If the model configuration dictionary contains any
                unexpected arguments to OpenAI API.
        """
        for param in self.model_config_dict:
            if param not in OPENAI_API_PARAMS:
                raise ValueError(
                    f"Unexpected argument `{param}` is "
                    "input into OpenAI model backend."
                )

    @property
    def stream(self) -> bool:
        r"""Returns whether the model is in stream mode, which sends partial
        results each time.

        Returns:
            bool: Whether the model is in stream mode.
        """
        return self.model_config_dict.get('stream', False)
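For orientation, here is a minimal sketch of driving this backend directly, outside of a ChatAgent. It assumes a valid ZHIPUAI_API_KEY (and, if needed, ZHIPUAI_API_BASE_URL) is set in the environment, that messages follow the OpenAI chat format expected by run(), and that the prompt and config values are illustrative only.

from camel.configs import ChatGPTConfig
from camel.models import ZhipuAIModel
from camel.types import ModelType

# Assumes ZHIPUAI_API_KEY / ZHIPUAI_API_BASE_URL are set in the environment.
model = ZhipuAIModel(
    model_type=ModelType.GLM_4,
    model_config_dict=ChatGPTConfig(temperature=0.2, top_p=0.9).__dict__,
)

# Messages use the OpenAI chat-completion format that run() expects.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Say hello in one sentence."},
]

response = model.run(messages)  # ChatCompletion, since stream is not enabled
print(response.choices[0].message.content)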
@@ -0,0 +1,117 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========

from camel.agents import ChatAgent
from camel.configs import ChatGPTConfig
from camel.messages import BaseMessage
from camel.types import ModelType

# Define system message
sys_msg = BaseMessage.make_assistant_message(
    role_name="Assistant",
    content="You are a helpful assistant.",
)

# Set model config
model_config = ChatGPTConfig(
    temperature=0.2, top_p=0.9
)  # temperature and top_p here cannot be 1 or 0.

# Set agent
camel_agent = ChatAgent(
    sys_msg,
    model_config=model_config,
    model_type=ModelType.GLM_4,
)
camel_agent.reset()

user_msg = BaseMessage.make_user_message(
    role_name="User",
    content="I want to practice my legs today. "
    "Help me make a fitness and diet plan",
)

# Get response information
response = camel_agent.step(user_msg)
print(response.msgs[0].content)
'''
===============================================================================
Certainly! Focusing on leg workouts can help improve strength, endurance, and
overall lower-body fitness. Here's a sample fitness
and diet plan for leg training:
**Fitness Plan:**
1. **Warm-Up:**
- 5-10 minutes of light cardio (jogging, cycling, or jumping jacks)
- Leg swings (forward and backward)
- Hip circles
2. **Strength Training:**
- Squats: 3 sets of 8-12 reps
- Deadlifts: 3 sets of 8-12 reps
- Lunges: 3 sets of 10-12 reps per leg
- Leg press: 3 sets of 10-12 reps
- Calf raises: 3 sets of 15-20 reps
3. **Cardio:**
- Hill sprints: 5-8 reps of 30-second sprints
- Cycling or stationary biking: 20-30 minutes at moderate intensity
4. **Cool Down:**
- Stretching (focus on the legs, hip flexors, and hamstrings)
- Foam rolling (optional)
**Diet Plan:**
1. **Breakfast:**
- Greek yogurt with mixed berries and a tablespoon of chia seeds
- Whole-grain toast with avocado
2. **Snack:**
- A banana with a tablespoon of natural peanut butter
3. **Lunch:**
- Grilled chicken breast with quinoa and steamed vegetables
- A side of mixed greens with a light vinaigrette
4. **Snack:**
- A serving of mixed nuts and dried fruits
5. **Dinner:**
- Baked salmon with sweet potato and roasted asparagus
- A side of lentil soup or a bean salad
6. **Post-Workout Snack:**
- A protein shake or a serving of cottage cheese with fruit
7. **Hydration:**
- Drink plenty of water throughout the day to
stay hydrated, especially after workouts.
**Tips:**
- Ensure you get enough rest and recovery, as leg workouts
can be demanding on the body.
- Listen to your body and adjust the weights and reps
according to your fitness level.
- Make sure to include a variety of nutrients in your diet to
support muscle recovery and overall health.
- Consult a fitness professional or trainer if you need personalized
guidance or have any pre-existing health conditions.
Remember, consistency is key to seeing results, so stick to
your plan and modify it as needed to suit your goals and progress.
===============================================================================
'''
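Note that the ChatAgent example above assumes the ZhipuAI credentials are already available: when no explicit api_key or url is passed, ZhipuAIModel.__init__ reads them from the environment. A minimal setup sketch follows; the base URL value is an assumption and should be confirmed against the ZhipuAI OpenAI-SDK documentation referenced in the backend.

import os

# Hypothetical values: take the real key and OpenAI-compatible endpoint from
# https://open.bigmodel.cn/dev/api#openai_sdk
os.environ["ZHIPUAI_API_KEY"] = "<your-zhipuai-api-key>"
os.environ["ZHIPUAI_API_BASE_URL"] = "https://open.bigmodel.cn/api/paas/v4/"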
@@ -0,0 +1,61 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
import re

import pytest

from camel.configs import ChatGPTConfig, OpenSourceConfig
from camel.models import ZhipuAIModel
from camel.types import ModelType
from camel.utils import OpenAITokenCounter


@pytest.mark.model_backend
@pytest.mark.parametrize(
    "model_type",
    [
        ModelType.GLM_3_TURBO,
        ModelType.GLM_4,
        ModelType.GLM_4V,
    ],
)
def test_zhipuai_model(model_type):
    model_config_dict = ChatGPTConfig().__dict__
    model = ZhipuAIModel(model_type, model_config_dict)
    assert model.model_type == model_type
    assert model.model_config_dict == model_config_dict
    assert isinstance(model.token_counter, OpenAITokenCounter)
    assert isinstance(model.model_type.value_for_tiktoken, str)
    assert isinstance(model.model_type.token_limit, int)


@pytest.mark.model_backend
def test_zhipuai_model_unexpected_argument():
    model_type = ModelType.GLM_4V
    model_config = OpenSourceConfig(
        model_path="vicuna-7b-v1.5",
        server_url="http://localhost:8000/v1",
    )
    model_config_dict = model_config.__dict__

    with pytest.raises(
        ValueError,
        match=re.escape(
            (
                "Unexpected argument `model_path` is "
                "input into OpenAI model backend."
            )
        ),
    ):
        _ = ZhipuAIModel(model_type, model_config_dict)
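For completeness, a small sketch of the behaviour the second test exercises: check_model_config is evidently invoked while the backend is constructed (the test wraps the constructor in pytest.raises), so any config key that is not an OpenAI API parameter is rejected immediately. The model_path key below mirrors the OpenSourceConfig field used in the test; the bare dict is illustrative.

from camel.models import ZhipuAIModel
from camel.types import ModelType

# `model_path` is not in OPENAI_API_PARAMS, so construction raises ValueError.
try:
    ZhipuAIModel(ModelType.GLM_4, {"model_path": "vicuna-7b-v1.5"})
except ValueError as err:
    print(err)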