Skip to content

Commit 53b9308

Browse files
feat: add_zhipuai_model (#600)
Co-authored-by: Wendong-Fan <[email protected]> Co-authored-by: Wendong <[email protected]>
1 parent 26100a9 commit 53b9308

File tree

7 files changed

+330
-0
lines changed

7 files changed

+330
-0
lines changed

camel/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@
1919
from .openai_audio_models import OpenAIAudioModels
2020
from .openai_model import OpenAIModel
2121
from .stub_model import StubModel
22+
from .zhipuai_model import ZhipuAIModel
2223

2324
__all__ = [
2425
'BaseModelBackend',
2526
'OpenAIModel',
2627
'AnthropicModel',
2728
'StubModel',
29+
'ZhipuAIModel',
2830
'OpenSourceModel',
2931
'ModelFactory',
3032
'LiteLLMModel',

camel/models/model_factory.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from camel.models.open_source_model import OpenSourceModel
1919
from camel.models.openai_model import OpenAIModel
2020
from camel.models.stub_model import StubModel
21+
from camel.models.zhipuai_model import ZhipuAIModel
2122
from camel.types import ModelType
2223

2324

@@ -58,6 +59,8 @@ def create(
5859
model_class = OpenSourceModel
5960
elif model_type.is_anthropic:
6061
model_class = AnthropicModel
62+
elif model_type.is_zhipuai:
63+
model_class = ZhipuAIModel
6164
else:
6265
raise ValueError(f"Unknown model type `{model_type}` is input")
6366

camel/models/zhipuai_model.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
2+
# Licensed under the Apache License, Version 2.0 (the “License”);
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an “AS IS” BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14+
15+
import os
16+
from typing import Any, Dict, List, Optional, Union
17+
18+
from openai import OpenAI, Stream
19+
20+
from camel.configs import OPENAI_API_PARAMS
21+
from camel.messages import OpenAIMessage
22+
from camel.models import BaseModelBackend
23+
from camel.types import ChatCompletion, ChatCompletionChunk, ModelType
24+
from camel.utils import (
25+
BaseTokenCounter,
26+
OpenAITokenCounter,
27+
model_api_key_required,
28+
)
29+
30+
31+
class ZhipuAIModel(BaseModelBackend):
32+
r"""ZhipuAI API in a unified BaseModelBackend interface."""
33+
34+
def __init__(
35+
self,
36+
model_type: ModelType,
37+
model_config_dict: Dict[str, Any],
38+
api_key: Optional[str] = None,
39+
url: Optional[str] = None,
40+
) -> None:
41+
r"""Constructor for ZhipuAI backend.
42+
43+
Args:
44+
model_type (ModelType): Model for which a backend is created,
45+
such as GLM_* series.
46+
model_config_dict (Dict[str, Any]): A dictionary that will
47+
be fed into openai.ChatCompletion.create().
48+
api_key (Optional[str]): The API key for authenticating with the
49+
ZhipuAI service. (default: :obj:`None`)
50+
"""
51+
super().__init__(model_type, model_config_dict)
52+
self._url = url or os.environ.get("ZHIPUAI_API_BASE_URL")
53+
self._api_key = api_key or os.environ.get("ZHIPUAI_API_KEY")
54+
self._client = OpenAI(
55+
timeout=60,
56+
max_retries=3,
57+
api_key=self._api_key,
58+
base_url=self._url,
59+
)
60+
self._token_counter: Optional[BaseTokenCounter] = None
61+
62+
@model_api_key_required
63+
def run(
64+
self,
65+
messages: List[OpenAIMessage],
66+
) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
67+
r"""Runs inference of OpenAI chat completion.
68+
69+
Args:
70+
messages (List[OpenAIMessage]): Message list with the chat history
71+
in OpenAI API format.
72+
73+
Returns:
74+
Union[ChatCompletion, Stream[ChatCompletionChunk]]:
75+
`ChatCompletion` in the non-stream mode, or
76+
`Stream[ChatCompletionChunk]` in the stream mode.
77+
"""
78+
# Use OpenAI cilent as interface call ZhipuAI
79+
# Reference: https://open.bigmodel.cn/dev/api#openai_sdk
80+
response = self._client.chat.completions.create(
81+
messages=messages,
82+
model=self.model_type.value,
83+
**self.model_config_dict,
84+
)
85+
return response
86+
87+
@property
88+
def token_counter(self) -> BaseTokenCounter:
89+
r"""Initialize the token counter for the model backend.
90+
91+
Returns:
92+
OpenAITokenCounter: The token counter following the model's
93+
tokenization style.
94+
"""
95+
96+
if not self._token_counter:
97+
# It's a temporary setting for token counter.
98+
self._token_counter = OpenAITokenCounter(ModelType.GPT_3_5_TURBO)
99+
return self._token_counter
100+
101+
def check_model_config(self):
102+
r"""Check whether the model configuration contains any
103+
unexpected arguments to OpenAI API.
104+
105+
Raises:
106+
ValueError: If the model configuration dictionary contains any
107+
unexpected arguments to OpenAI API.
108+
"""
109+
for param in self.model_config_dict:
110+
if param not in OPENAI_API_PARAMS:
111+
raise ValueError(
112+
f"Unexpected argument `{param}` is "
113+
"input into OpenAI model backend."
114+
)
115+
pass
116+
117+
@property
118+
def stream(self) -> bool:
119+
r"""Returns whether the model is in stream mode, which sends partial
120+
results each time.
121+
122+
Returns:
123+
bool: Whether the model is in stream mode.
124+
"""
125+
return self.model_config_dict.get('stream', False)

camel/types/enums.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ class ModelType(Enum):
2929
GPT_4_32K = "gpt-4-32k"
3030
GPT_4_TURBO = "gpt-4-turbo"
3131
GPT_4O = "gpt-4o"
32+
GLM_4 = "glm-4"
33+
GLM_4V = 'glm-4v'
34+
GLM_3_TURBO = "glm-3-turbo"
3235

3336
STUB = "stub"
3437

@@ -62,6 +65,15 @@ def is_openai(self) -> bool:
6265
ModelType.GPT_4O,
6366
}
6467

68+
@property
69+
def is_zhipuai(self) -> bool:
70+
r"""Returns whether this type of models is an ZhipuAI model."""
71+
return self in {
72+
ModelType.GLM_3_TURBO,
73+
ModelType.GLM_4,
74+
ModelType.GLM_4V,
75+
}
76+
6577
@property
6678
def is_open_source(self) -> bool:
6779
r"""Returns whether this type of models is open-source."""
@@ -103,6 +115,12 @@ def token_limit(self) -> int:
103115
return 128000
104116
elif self is ModelType.GPT_4O:
105117
return 128000
118+
elif self == ModelType.GLM_4:
119+
return 8192
120+
elif self == ModelType.GLM_3_TURBO:
121+
return 8192
122+
elif self == ModelType.GLM_4V:
123+
return 1024
106124
elif self is ModelType.STUB:
107125
return 4096
108126
elif self is ModelType.LLAMA_2:

camel/utils/commons.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ def wrapper(self, *args, **kwargs):
5454
if not self._api_key and 'OPENAI_API_KEY' not in os.environ:
5555
raise ValueError('OpenAI API key not found.')
5656
return func(self, *args, **kwargs)
57+
elif self.model_type.is_zhipuai:
58+
if 'ZHIPUAI_API_KEY' not in os.environ:
59+
raise ValueError('ZhiPuAI API key not found.')
60+
return func(self, *args, **kwargs)
5761
elif self.model_type.is_anthropic:
5862
if not self._api_key and 'ANTHROPIC_API_KEY' not in os.environ:
5963
raise ValueError('Anthropic API key not found.')
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
2+
# Licensed under the Apache License, Version 2.0 (the “License”);
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an “AS IS” BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14+
15+
from camel.agents import ChatAgent
16+
from camel.configs import ChatGPTConfig
17+
from camel.messages import BaseMessage
18+
from camel.types import ModelType
19+
20+
# Define system message
21+
sys_msg = BaseMessage.make_assistant_message(
22+
role_name="Assistant",
23+
content="You are a helpful assistant.",
24+
)
25+
26+
# Set model config
27+
model_config = ChatGPTConfig(
28+
temperature=0.2, top_p=0.9
29+
) # temperature=,top_p here can not be 1 or 0.
30+
31+
# Set agent
32+
camel_agent = ChatAgent(
33+
sys_msg,
34+
model_config=model_config,
35+
model_type=ModelType.GLM_4,
36+
)
37+
camel_agent.reset()
38+
39+
user_msg = BaseMessage.make_user_message(
40+
role_name="User",
41+
content="I want to practice my legs today."
42+
"Help me make a fitness and diet plan",
43+
)
44+
45+
# Get response information
46+
response = camel_agent.step(user_msg)
47+
print(response.msgs[0].content)
48+
'''
49+
===============================================================================
50+
Certainly! Focusing on leg workouts can help improve strength, endurance, and
51+
overall lower-body fitness. Here's a sample fitness
52+
and diet plan for leg training:
53+
54+
**Fitness Plan:**
55+
56+
1. **Warm-Up:**
57+
- 5-10 minutes of light cardio (jogging, cycling, or jumping jacks)
58+
- Leg swings (forward and backward)
59+
- Hip circles
60+
61+
2. **Strength Training:**
62+
- Squats: 3 sets of 8-12 reps
63+
- Deadlifts: 3 sets of 8-12 reps
64+
- Lunges: 3 sets of 10-12 reps per leg
65+
- Leg press: 3 sets of 10-12 reps
66+
- Calf raises: 3 sets of 15-20 reps
67+
68+
3. **Cardio:**
69+
- Hill sprints: 5-8 reps of 30-second sprints
70+
- Cycling or stationary biking: 20-30 minutes at moderate intensity
71+
72+
4. **Cool Down:**
73+
- Stretching (focus on the legs, hip flexors, and hamstrings)
74+
- Foam rolling (optional)
75+
76+
**Diet Plan:**
77+
78+
1. **Breakfast:**
79+
- Greek yogurt with mixed berries and a tablespoon of chia seeds
80+
- Whole-grain toast with avocado
81+
82+
2. **Snack:**
83+
- A banana with a tablespoon of natural peanut butter
84+
85+
3. **Lunch:**
86+
- Grilled chicken breast with quinoa and steamed vegetables
87+
- A side of mixed greens with a light vinaigrette
88+
89+
4. **Snack:**
90+
- A serving of mixed nuts and dried fruits
91+
92+
5. **Dinner:**
93+
- Baked salmon with sweet potato and roasted asparagus
94+
- A side of lentil soup or a bean salad
95+
96+
6. **Post-Workout Snack:**
97+
- A protein shake or a serving of cottage cheese with fruit
98+
99+
7. **Hydration:**
100+
- Drink plenty of water throughout the day to
101+
stay hydrated, especially after workouts.
102+
103+
**Tips:**
104+
105+
- Ensure you get enough rest and recovery, as leg workouts
106+
can be demanding on the body.
107+
- Listen to your body and adjust the weights and reps
108+
according to your fitness level.
109+
- Make sure to include a variety of nutrients in your diet to
110+
support muscle recovery and overall health.
111+
- Consult a fitness professional or trainer if you need personalized
112+
guidance or have any pre-existing health conditions.
113+
114+
Remember, consistency is key to seeing results, so stick to
115+
your plan and modify it as needed to suit your goals and progress.
116+
===============================================================================
117+
'''

test/models/test_zhipuai_model.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
2+
# Licensed under the Apache License, Version 2.0 (the “License”);
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an “AS IS” BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14+
import re
15+
16+
import pytest
17+
18+
from camel.configs import ChatGPTConfig, OpenSourceConfig
19+
from camel.models import ZhipuAIModel
20+
from camel.types import ModelType
21+
from camel.utils import OpenAITokenCounter
22+
23+
24+
@pytest.mark.model_backend
25+
@pytest.mark.parametrize(
26+
"model_type",
27+
[
28+
ModelType.GLM_3_TURBO,
29+
ModelType.GLM_4,
30+
ModelType.GLM_4V,
31+
],
32+
)
33+
def test_zhipuai_model(model_type):
34+
model_config_dict = ChatGPTConfig().__dict__
35+
model = ZhipuAIModel(model_type, model_config_dict)
36+
assert model.model_type == model_type
37+
assert model.model_config_dict == model_config_dict
38+
assert isinstance(model.token_counter, OpenAITokenCounter)
39+
assert isinstance(model.model_type.value_for_tiktoken, str)
40+
assert isinstance(model.model_type.token_limit, int)
41+
42+
43+
@pytest.mark.model_backend
44+
def test_zhipuai_model_unexpected_argument():
45+
model_type = ModelType.GLM_4V
46+
model_config = OpenSourceConfig(
47+
model_path="vicuna-7b-v1.5",
48+
server_url="http://localhost:8000/v1",
49+
)
50+
model_config_dict = model_config.__dict__
51+
52+
with pytest.raises(
53+
ValueError,
54+
match=re.escape(
55+
(
56+
"Unexpected argument `model_path` is "
57+
"input into OpenAI model backend."
58+
)
59+
),
60+
):
61+
_ = ZhipuAIModel(model_type, model_config_dict)

0 commit comments

Comments
 (0)