Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@ OPENAI_API_KEY=
OPENAI_API_BASE=
OPENAI_MODEL=
SILICON_API_KEY=
SILICON_MODEL=
SILICON_MODEL=
MINIMAX_API_KEY=
MINIMAX_MODEL=
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ python -m mindsearch.app --lang en --model_format internlm_server --search_engin
- `--model_format`: format of the model.
- `internlm_server` for InternLM2.5-7b-chat with local server. (InternLM2.5-7b-chat has been better optimized for Chinese.)
- `gpt4` for GPT4.
- `minimax` for [MiniMax](https://platform.minimaxi.com/) (MiniMax-M1 by default, configurable via `MINIMAX_MODEL` env var).
if you want to use other models, please modify [models](./mindsearch/agent/models.py)
- `--search_engine`: Search engine.
- `DuckDuckGoSearch` for search engine for DuckDuckGo.
Expand Down
1 change: 1 addition & 0 deletions README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ python -m mindsearch.app --lang en --model_format internlm_server --search_engin
- `--model_format`: 模型的格式。
- `internlm_server` 为 InternLM2.5-7b-chat 本地服务器。
- `gpt4` 为 GPT4。
- `minimax` 为 [MiniMax](https://platform.minimaxi.com/)(默认使用 MiniMax-M1 模型,可通过 `MINIMAX_MODEL` 环境变量配置)。
如果您想使用其他模型,请修改 [models](./mindsearch/agent/models.py)
- `--search_engine`: 搜索引擎。
- `DuckDuckGoSearch` 为 DuckDuckGo 搜索引擎。
Expand Down
147 changes: 147 additions & 0 deletions mindsearch/agent/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import json
import os
import time
import traceback
import warnings

import requests
from dotenv import load_dotenv
from lagent.llms import (
GPTAPI,
Expand All @@ -9,6 +14,130 @@
LMDeployServer,
)


class MiniMaxAPI(GPTAPI):
    """GPTAPI subclass for MiniMax's OpenAI-compatible API.

    MiniMax models use a standard OpenAI-compatible chat completions
    format.  This subclass extends GPTAPI to handle MiniMax model types
    that are not recognized by the base class's model routing logic, and
    to tolerate streaming chunks that omit ``finish_reason``.
    """

    def generate_request_data(self, model_type, messages, gen_params,
                              json_mode=False):
        """Build the HTTP header and JSON payload for a chat request.

        Args:
            model_type (str): MiniMax model name, e.g. ``MiniMax-M1``.
            messages (list[dict]): OpenAI-style chat messages.
            gen_params (dict): generation parameters; ``max_new_tokens``
                must be present and is translated to ``max_tokens``.
            json_mode (bool): request a JSON-object response format.

        Returns:
            tuple: ``(header, data)`` on success, or ``('', '')`` when
            the token budget is already exhausted (sentinel convention
            inherited from the base class).
        """
        gen_params = gen_params.copy()
        # NOTE(review): hard cap of 4096 completion tokens -- presumably
        # a MiniMax per-request limit; confirm against the API docs.
        max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
        if max_tokens <= 0:
            return '', ''
        header = {'content-type': 'application/json'}
        gen_params['max_tokens'] = max_tokens
        # Translate lagent naming to the OpenAI-compatible field name.
        if 'stop_words' in gen_params:
            gen_params['stop'] = gen_params.pop('stop_words')
        # Drop parameters the MiniMax endpoint does not accept.
        gen_params.pop('repetition_penalty', None)
        if 'top_k' in gen_params:
            warnings.warn(
                '`top_k` parameter is not supported by MiniMax API.',
                DeprecationWarning)
            gen_params.pop('top_k')
        gen_params.pop('skip_special_tokens', None)
        gen_params.pop('session_id', None)
        data = {
            'model': model_type,
            'messages': messages,
            'n': 1,
            **gen_params,
        }
        if json_mode:
            data['response_format'] = {'type': 'json_object'}
        return header, data

    def _stream_chat(self, messages, **gen_params):
        """Stream a chat completion from the MiniMax API.

        Overridden because MiniMax streaming chunks may omit
        ``finish_reason`` in non-final chunks (standard OpenAI SSE
        behaviour) while the base class assumes the key is always
        present.  Also fixes the error handling, which previously
        inspected an always-empty local dict and so never classified
        API errors.

        Args:
            messages (list[dict]): OpenAI-style chat messages.
            **gen_params: generation parameters forwarded to
                :meth:`generate_request_data`.

        Returns:
            generator: yields content fragments as they arrive.

        Raises:
            RuntimeError: when every key is out of quota or all retries
                are exhausted.
        """

        def streaming(raw_response):
            # Parse the SSE body line by line; each data line carries
            # one JSON chunk in OpenAI delta format.
            for chunk in raw_response.iter_lines(
                    chunk_size=8192, decode_unicode=False, delimiter=b'\n'):
                if not chunk:
                    continue
                decoded = chunk.decode('utf-8')
                if decoded.startswith('data: [DONE]'):
                    return
                if decoded[:5] == 'data:':
                    decoded = decoded[5:].lstrip()
                else:
                    # Skip SSE comments / keep-alive lines.
                    continue
                try:
                    response = json.loads(decoded)
                    choice = response['choices'][0]
                    # finish_reason is absent or None until the final
                    # chunk, hence .get() rather than indexing.
                    if choice.get('finish_reason') == 'stop':
                        return
                    yield choice.get('delta', {}).get('content', '')
                except Exception as exc:
                    msg = (f'response {decoded} lead to exception '
                           f'of {str(exc)}')
                    self.logger.error(msg)
                    raise Exception(msg) from exc

        assert isinstance(messages, list)
        header, data = self.generate_request_data(
            model_type=self.model_type,
            messages=messages,
            gen_params=gen_params,
            json_mode=self.json_mode)

        max_num_retries, errmsg = 0, ''
        while max_num_retries < self.retry:
            if len(self.invalid_keys) == len(self.keys):
                raise RuntimeError('All keys have insufficient quota.')
            # Round-robin over the configured keys, skipping any that
            # have been marked invalid.
            while True:
                self.key_ctr += 1
                if self.key_ctr == len(self.keys):
                    self.key_ctr = 0
                if self.keys[self.key_ctr] not in self.invalid_keys:
                    break
            key = self.keys[self.key_ctr]
            header['Authorization'] = f'Bearer {key}'
            try:
                raw_response = requests.post(
                    self.url,
                    headers=header,
                    data=json.dumps(data),
                    proxies=self.proxies,
                    stream=True)
                if raw_response.status_code == 200:
                    return streaming(raw_response)
                # BUG FIX: the previous version caught KeyError and then
                # inspected an always-empty dict, so rate-limit and
                # quota errors were never recognised.  Parse the error
                # body from the actual HTTP response instead.
                try:
                    response = raw_response.json()
                except ValueError:
                    response = {}
                error = response.get('error', {})
                if error.get('code') == 'rate_limit_exceeded':
                    # Back off briefly; rate limiting does not consume
                    # a retry.
                    time.sleep(1)
                    continue
                if error.get('code') == 'insufficient_quota':
                    self.invalid_keys.add(key)
                    # logger.warn is deprecated; use warning().
                    self.logger.warning(f'insufficient_quota key: {key}')
                    continue
                errmsg = ('Find error message in response: '
                          + str(error or raw_response.text))
                self.logger.error(errmsg)
            except requests.ConnectionError:
                # Connection errors now consume a retry (previously a
                # bare `continue` skipped the counter, looping forever
                # on an unreachable endpoint).
                errmsg = 'Got connection error ' + str(
                    traceback.format_exc())
                self.logger.error(errmsg)
            except Exception as error:
                errmsg = str(error) + '\n' + str(
                    traceback.format_exc())
                self.logger.error(errmsg)
            max_num_retries += 1
        raise RuntimeError(
            'Calling MiniMax failed after retrying for '
            f'{max_num_retries} times. Check the logs for '
            f'details. errmsg: {errmsg}')

internlm_server = dict(
type=LMDeployServer,
path="internlm/internlm2_5-7b-chat",
Expand Down Expand Up @@ -93,3 +222,21 @@
repetition_penalty=1.02,
stop_words=["<|im_end|>"],
)

# MiniMax Cloud API (OpenAI-compatible).
# Obtain an API key from https://platform.minimaxi.com/
minimax = {
    "type": MiniMaxAPI,
    # Model and key are configurable via environment variables; the key
    # placeholder makes a missing MINIMAX_API_KEY obvious at request time.
    "model_type": os.environ.get("MINIMAX_MODEL", "MiniMax-M1"),
    "key": os.environ.get("MINIMAX_API_KEY", "YOUR MINIMAX API KEY"),
    "api_base": "https://api.minimax.io/v1/chat/completions",
    # Map lagent roles onto the OpenAI-compatible API roles; the
    # "environment" role is presented to the model as a system message.
    "meta_template": [
        {"role": "system", "api_role": "system"},
        {"role": "user", "api_role": "user"},
        {"role": "assistant", "api_role": "assistant"},
        {"role": "environment", "api_role": "system"},
    ],
    "top_p": 0.8,
    "temperature": 0.01,
    "max_new_tokens": 8192,
}
Empty file added tests/__init__.py
Empty file.
68 changes: 68 additions & 0 deletions tests/test_minimax_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Integration tests for MiniMax model in MindSearch.

These tests require a valid MINIMAX_API_KEY environment variable.
Skip with: pytest -m "not integration"
"""
import os
import unittest

import pytest

# Key is read once at import time; when unset, the whole test class is
# skipped via the unittest.skipUnless decorator below.
MINIMAX_API_KEY = os.environ.get("MINIMAX_API_KEY", "")
SKIP_REASON = "MINIMAX_API_KEY not set"


@pytest.mark.integration
@unittest.skipUnless(MINIMAX_API_KEY, SKIP_REASON)
class TestMiniMaxAPIIntegration(unittest.TestCase):
    """Integration tests that call the live MiniMax API."""

    @staticmethod
    def _make_llm():
        """Build a fresh LLM object from the shared ``minimax`` config.

        Imports are deferred to call time so that test *collection*
        succeeds even when the optional ``lagent`` / ``mindsearch``
        dependencies are not installed (the class is skipped without a
        key anyway).  Previously this boilerplate was duplicated in
        every test method.
        """
        from copy import deepcopy

        from lagent.utils import create_object

        from mindsearch.agent import models as llm_factory

        # deepcopy so create_object cannot mutate the shared config.
        return create_object(deepcopy(llm_factory.minimax))

    def test_minimax_gptapi_basic_chat(self):
        """Test that MiniMax GPTAPI can complete a basic chat request."""
        llm = self._make_llm()
        # GPTAPI.chat expects a list of message dicts.
        response = llm.chat(
            [{"role": "user", "content": "Say hello in one word."}]
        )
        self.assertIsInstance(response, str)
        self.assertGreater(len(response.strip()), 0)

    def test_minimax_gptapi_streaming(self):
        """Test that MiniMax GPTAPI supports streaming responses."""
        llm = self._make_llm()
        # Materialize the stream; at least one chunk must arrive.
        chunks = list(llm.stream_chat(
            [{"role": "user", "content": "Count from 1 to 3."}]
        ))
        self.assertGreater(len(chunks), 0)

    def test_minimax_init_agent_resolves(self):
        """Test that the minimax config resolves to an LLM object."""
        llm = self._make_llm()
        self.assertIsNotNone(llm)


# Allow direct execution (python tests/test_minimax_integration.py) in
# addition to pytest discovery.
if __name__ == "__main__":
    unittest.main()
Loading