diff --git a/.env.example b/.env.example
index 86c7798..8107454 100644
--- a/.env.example
+++ b/.env.example
@@ -2,4 +2,6 @@ OPENAI_API_KEY=
 OPENAI_API_BASE=
 OPENAI_MODEL=
 SILICON_API_KEY=
-SILICON_MODEL=
\ No newline at end of file
+SILICON_MODEL=
+MINIMAX_API_KEY=
+MINIMAX_MODEL=
\ No newline at end of file
diff --git a/README.md b/README.md
index 2710c61..69b5ed5 100644
--- a/README.md
+++ b/README.md
@@ -57,6 +57,7 @@ python -m mindsearch.app --lang en --model_format internlm_server --search_engin
 - `--model_format`: format of the model.
   - `internlm_server` for InternLM2.5-7b-chat with local server. (InternLM2.5-7b-chat has been better optimized for Chinese.)
   - `gpt4` for GPT4.
+  - `minimax` for [MiniMax](https://platform.minimaxi.com/) (MiniMax-M1 by default, configurable via `MINIMAX_MODEL` env var).
   if you want to use other models, please modify [models](./mindsearch/agent/models.py)
 - `--search_engine`: Search engine.
   - `DuckDuckGoSearch` for search engine for DuckDuckGo.
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 073a1b9..a22b39a 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -58,6 +58,7 @@ python -m mindsearch.app --lang en --model_format internlm_server --search_engin
 - `--model_format`: 模型的格式。
   - `internlm_server` 为 InternLM2.5-7b-chat 本地服务器。
   - `gpt4` 为 GPT4。
+  - `minimax` 为 [MiniMax](https://platform.minimaxi.com/)(默认使用 MiniMax-M1 模型,可通过 `MINIMAX_MODEL` 环境变量配置)。
   如果您想使用其他模型,请修改 [models](./mindsearch/agent/models.py)
 - `--search_engine`: 搜索引擎。
   - `DuckDuckGoSearch` 为 DuckDuckGo 搜索引擎。
diff --git a/mindsearch/agent/models.py b/mindsearch/agent/models.py
index 4682858..1539878 100644
--- a/mindsearch/agent/models.py
+++ b/mindsearch/agent/models.py
@@ -1,5 +1,10 @@
+import json
 import os
+import time
+import traceback
+import warnings
 
+import requests
 from dotenv import load_dotenv
 from lagent.llms import (
     GPTAPI,
@@ -9,6 +14,130 @@
     LMDeployServer,
 )
 
+
+class MiniMaxAPI(GPTAPI):
+    """GPTAPI subclass for MiniMax's OpenAI-compatible API.
+
+    MiniMax models use a standard OpenAI-compatible chat completions format.
+    This subclass extends GPTAPI to handle MiniMax model types that are not
+    recognized by the base class's model routing logic.
+    """
+
+    def generate_request_data(self, model_type, messages, gen_params,
+                              json_mode=False):
+        gen_params = gen_params.copy()
+        max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
+        if max_tokens <= 0:
+            return '', ''
+        header = {'content-type': 'application/json'}
+        gen_params['max_tokens'] = max_tokens
+        if 'stop_words' in gen_params:
+            gen_params['stop'] = gen_params.pop('stop_words')
+        gen_params.pop('repetition_penalty', None)
+        if 'top_k' in gen_params:
+            warnings.warn(
+                '`top_k` parameter is not supported by MiniMax API.',
+                DeprecationWarning)
+            gen_params.pop('top_k')
+        gen_params.pop('skip_special_tokens', None)
+        gen_params.pop('session_id', None)
+        data = {
+            'model': model_type,
+            'messages': messages,
+            'n': 1,
+            **gen_params,
+        }
+        if json_mode:
+            data['response_format'] = {'type': 'json_object'}
+        return header, data
+
+    def _stream_chat(self, messages, **gen_params):
+        """Override to fix finish_reason handling for MiniMax streaming.
+
+        MiniMax streaming chunks may omit ``finish_reason`` in non-final
+        chunks (standard OpenAI SSE behaviour), but the base class assumes
+        the key is always present.
+        """
+
+        def streaming(raw_response):
+            for chunk in raw_response.iter_lines(
+                    chunk_size=8192, decode_unicode=False, delimiter=b'\n'):
+                if chunk:
+                    decoded = chunk.decode('utf-8')
+                    if decoded.startswith('data: [DONE]'):
+                        return
+                    if decoded[:5] == 'data:':
+                        decoded = decoded[5:].lstrip()
+                    else:
+                        continue
+                    try:
+                        response = json.loads(decoded)
+                        choice = response['choices'][0]
+                        if choice.get('finish_reason') == 'stop':
+                            return
+                        yield choice.get('delta', {}).get('content', '')
+                    except Exception as exc:
+                        msg = (f'response {decoded} lead to exception '
+                               f'of {str(exc)}')
+                        self.logger.error(msg)
+                        raise Exception(msg) from exc
+
+        assert isinstance(messages, list)
+        header, data = self.generate_request_data(
+            model_type=self.model_type,
+            messages=messages,
+            gen_params=gen_params,
+            json_mode=self.json_mode)
+
+        max_num_retries, errmsg = 0, ''
+        while max_num_retries < self.retry:
+            if len(self.invalid_keys) == len(self.keys):
+                raise RuntimeError('All keys have insufficient quota.')
+            while True:
+                self.key_ctr += 1
+                if self.key_ctr == len(self.keys):
+                    self.key_ctr = 0
+                if self.keys[self.key_ctr] not in self.invalid_keys:
+                    break
+            key = self.keys[self.key_ctr]
+            header['Authorization'] = f'Bearer {key}'
+            response = dict()
+            try:
+                raw_response = requests.post(
+                    self.url,
+                    headers=header,
+                    data=json.dumps(data),
+                    proxies=self.proxies,
+                    stream=True)
+                return streaming(raw_response)
+            except requests.ConnectionError:
+                errmsg = 'Got connection error ' + str(
+                    traceback.format_exc())
+                self.logger.error(errmsg)
+                continue
+            except KeyError:
+                if 'error' in response:
+                    if response['error']['code'] == 'rate_limit_exceeded':
+                        time.sleep(1)
+                        continue
+                    elif response['error']['code'] == 'insufficient_quota':
+                        self.invalid_keys.add(key)
+                        self.logger.warn(
+                            f'insufficient_quota key: {key}')
+                        continue
+                    errmsg = ('Find error message in response: ' +
+                              str(response['error']))
+                    self.logger.error(errmsg)
+            except Exception as error:
+                errmsg = str(error) + '\n' + str(
+                    traceback.format_exc())
+                self.logger.error(errmsg)
+            max_num_retries += 1
+        raise RuntimeError(
+            'Calling OpenAI failed after retrying for '
+            f'{max_num_retries} times. Check the logs for '
+            f'details. errmsg: {errmsg}')
+
 
 internlm_server = dict(
     type=LMDeployServer,
     path="internlm/internlm2_5-7b-chat",
@@ -93,3 +222,21 @@
     repetition_penalty=1.02,
     stop_words=["<|im_end|>"],
 )
+
+# MiniMax Cloud API (OpenAI-compatible)
+# Get your API key from https://platform.minimaxi.com/
+minimax = dict(
+    type=MiniMaxAPI,
+    model_type=os.environ.get("MINIMAX_MODEL", "MiniMax-M1"),
+    key=os.environ.get("MINIMAX_API_KEY", "YOUR MINIMAX API KEY"),
+    api_base="https://api.minimax.io/v1/chat/completions",
+    meta_template=[
+        dict(role="system", api_role="system"),
+        dict(role="user", api_role="user"),
+        dict(role="assistant", api_role="assistant"),
+        dict(role="environment", api_role="system"),
+    ],
+    top_p=0.8,
+    temperature=0.01,
+    max_new_tokens=8192,
+)
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_minimax_integration.py b/tests/test_minimax_integration.py
new file mode 100644
index 0000000..4a3d2f7
--- /dev/null
+++ b/tests/test_minimax_integration.py
@@ -0,0 +1,68 @@
+"""Integration tests for MiniMax model in MindSearch.
+
+These tests require a valid MINIMAX_API_KEY environment variable.
+Skip with: pytest -m "not integration"
+"""
+import os
+import unittest
+
+import pytest
+
+MINIMAX_API_KEY = os.environ.get("MINIMAX_API_KEY", "")
+SKIP_REASON = "MINIMAX_API_KEY not set"
+
+
+@pytest.mark.integration
+@unittest.skipUnless(MINIMAX_API_KEY, SKIP_REASON)
+class TestMiniMaxAPIIntegration(unittest.TestCase):
+    """Integration tests that call the MiniMax API."""
+
+    def test_minimax_gptapi_basic_chat(self):
+        """Test that MiniMax GPTAPI can complete a basic chat request."""
+        from copy import deepcopy
+
+        from lagent.utils import create_object
+
+        from mindsearch.agent import models as llm_factory
+
+        cfg = deepcopy(llm_factory.minimax)
+        llm = create_object(cfg)
+        # GPTAPI.chat expects a list of message dicts
+        response = llm.chat(
+            [{"role": "user", "content": "Say hello in one word."}]
+        )
+        self.assertIsInstance(response, str)
+        self.assertGreater(len(response.strip()), 0)
+
+    def test_minimax_gptapi_streaming(self):
+        """Test that MiniMax GPTAPI supports streaming responses."""
+        from copy import deepcopy
+
+        from lagent.utils import create_object
+
+        from mindsearch.agent import models as llm_factory
+
+        cfg = deepcopy(llm_factory.minimax)
+        llm = create_object(cfg)
+        chunks = []
+        for chunk in llm.stream_chat(
+            [{"role": "user", "content": "Count from 1 to 3."}]
+        ):
+            chunks.append(chunk)
+        self.assertGreater(len(chunks), 0)
+
+    def test_minimax_init_agent_resolves(self):
+        """Test that init_agent can create an LLM object for minimax format."""
+        from copy import deepcopy
+
+        from lagent.utils import create_object
+
+        from mindsearch.agent import models as llm_factory
+
+        cfg = deepcopy(getattr(llm_factory, "minimax"))
+        llm = create_object(cfg)
+        self.assertIsNotNone(llm)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_minimax_models.py b/tests/test_minimax_models.py
new file mode 100644
index 0000000..876514e
--- /dev/null
+++ b/tests/test_minimax_models.py
@@ -0,0 +1,200 @@
+"""Unit tests for MiniMax model configuration in MindSearch."""
+import os
+import unittest
+from unittest.mock import patch
+
+
+class TestMiniMaxModelConfig(unittest.TestCase):
+    """Test MiniMax model preset configuration."""
+
+    def test_minimax_preset_exists(self):
+        """Test that the minimax preset is defined in models module."""
+        from mindsearch.agent import models as llm_factory
+
+        self.assertTrue(
+            hasattr(llm_factory, "minimax"),
+            "minimax preset should be defined in models.py",
+        )
+
+    def test_minimax_preset_is_dict(self):
+        """Test that the minimax preset is a dictionary."""
+        from mindsearch.agent import models as llm_factory
+
+        self.assertIsInstance(llm_factory.minimax, dict)
+
+    def test_minimax_uses_minimaxapi_type(self):
+        """Test that MiniMax preset uses the MiniMaxAPI subclass."""
+        from mindsearch.agent import models as llm_factory
+        from lagent.llms import GPTAPI
+
+        self.assertTrue(
+            issubclass(llm_factory.minimax["type"], GPTAPI),
+            "MiniMaxAPI should be a subclass of GPTAPI",
+        )
+        self.assertEqual(llm_factory.minimax["type"].__name__, "MiniMaxAPI")
+
+    def test_minimax_api_base(self):
+        """Test that MiniMax preset uses the correct API endpoint."""
+        from mindsearch.agent import models as llm_factory
+
+        self.assertEqual(
+            llm_factory.minimax["api_base"],
+            "https://api.minimax.io/v1/chat/completions",
+        )
+
+    def test_minimax_default_model(self):
+        """Test that MiniMax preset defaults to MiniMax-M1."""
+        with patch.dict(os.environ, {}, clear=False):
+            # Remove MINIMAX_MODEL if set
+            env = os.environ.copy()
+            env.pop("MINIMAX_MODEL", None)
+            with patch.dict(os.environ, env, clear=True):
+                # Re-import to pick up the env
+                import importlib
+                from mindsearch.agent import models
+
+                importlib.reload(models)
+                self.assertEqual(models.minimax["model_type"], "MiniMax-M1")
+
+    def test_minimax_custom_model_via_env(self):
+        """Test that MINIMAX_MODEL env var overrides the default model."""
+        with patch.dict(os.environ, {"MINIMAX_MODEL": "MiniMax-M1-highspeed"}):
+            import importlib
+            from mindsearch.agent import models
+
+            importlib.reload(models)
+            self.assertEqual(
+                models.minimax["model_type"], "MiniMax-M1-highspeed"
+            )
+
+    def test_minimax_api_key_from_env(self):
+        """Test that MINIMAX_API_KEY env var is used for authentication."""
+        test_key = "test-minimax-key-12345"
+        with patch.dict(os.environ, {"MINIMAX_API_KEY": test_key}):
+            import importlib
+            from mindsearch.agent import models
+
+            importlib.reload(models)
+            self.assertEqual(models.minimax["key"], test_key)
+
+    def test_minimax_meta_template(self):
+        """Test that MiniMax preset has proper role mapping."""
+        from mindsearch.agent import models as llm_factory
+
+        meta = llm_factory.minimax["meta_template"]
+        self.assertIsInstance(meta, list)
+        roles = {item["role"] for item in meta}
+        self.assertIn("system", roles)
+        self.assertIn("user", roles)
+        self.assertIn("assistant", roles)
+        self.assertIn("environment", roles)
+
+    def test_minimax_temperature_within_range(self):
+        """Test that MiniMax temperature is within valid range (0, 1]."""
+        from mindsearch.agent import models as llm_factory
+
+        temp = llm_factory.minimax["temperature"]
+        self.assertGreater(temp, 0)
+        self.assertLessEqual(temp, 1)
+
+    def test_minimax_has_max_new_tokens(self):
+        """Test that MiniMax preset specifies max_new_tokens."""
+        from mindsearch.agent import models as llm_factory
+
+        self.assertIn("max_new_tokens", llm_factory.minimax)
+        self.assertGreater(llm_factory.minimax["max_new_tokens"], 0)
+
+    def test_minimax_no_unsupported_params(self):
+        """Test that MiniMax preset does not include unsupported parameters."""
+        from mindsearch.agent import models as llm_factory
+
+        # MiniMax API doesn't support repetition_penalty or stop_words
+        # in the same way as local InternLM models
+        self.assertNotIn("repetition_penalty", llm_factory.minimax)
+        self.assertNotIn("stop_words", llm_factory.minimax)
+
+
+class TestMiniMaxInitAgent(unittest.TestCase):
+    """Test that MiniMax can be used with the agent initialization flow."""
+
+    def test_init_agent_accepts_minimax_format(self):
+        """Test that init_agent can resolve the minimax model_format."""
+        from mindsearch.agent import models as llm_factory
+
+        # Verify getattr works as used in __init__.py
+        cfg = getattr(llm_factory, "minimax", None)
+        self.assertIsNotNone(cfg, "minimax should be resolvable via getattr")
+        self.assertIn("type", cfg)
+        self.assertIn("model_type", cfg)
+        self.assertIn("key", cfg)
+        self.assertIn("api_base", cfg)
+
+    def test_all_model_presets_have_required_fields(self):
+        """Test that all GPTAPI-based presets (including minimax) have required fields."""
+        from mindsearch.agent import models as llm_factory
+        from lagent.llms import GPTAPI
+
+        gptapi_presets = ["gpt4", "qwen", "internlm_silicon", "minimax"]
+        for name in gptapi_presets:
+            cfg = getattr(llm_factory, name, None)
+            if cfg is None:
+                continue
+            self.assertTrue(
+                issubclass(cfg["type"], GPTAPI),
+                f"{name} should use GPTAPI or a subclass",
+            )
+            self.assertIn("model_type", cfg, f"{name} missing model_type")
+            self.assertIn("key", cfg, f"{name} missing key")
+            self.assertIn("api_base", cfg, f"{name} missing api_base")
+
+    def test_minimaxapi_generate_request_data(self):
+        """Test that MiniMaxAPI generates correct request data."""
+        from mindsearch.agent.models import MiniMaxAPI
+
+        api = MiniMaxAPI.__new__(MiniMaxAPI)
+        messages = [{"role": "user", "content": "hello"}]
+        gen_params = {"max_new_tokens": 100, "temperature": 0.5}
+        header, data = api.generate_request_data(
+            "MiniMax-M1", messages, gen_params
+        )
+        self.assertEqual(header["content-type"], "application/json")
+        self.assertEqual(data["model"], "MiniMax-M1")
+        self.assertEqual(data["messages"], messages)
+        self.assertEqual(data["max_tokens"], 100)
+        self.assertNotIn("max_new_tokens", data)
+
+    def test_minimaxapi_strips_unsupported_params(self):
+        """Test that MiniMaxAPI removes unsupported parameters."""
+        from mindsearch.agent.models import MiniMaxAPI
+
+        api = MiniMaxAPI.__new__(MiniMaxAPI)
+        gen_params = {
+            "max_new_tokens": 100,
+            "repetition_penalty": 1.02,
+            "skip_special_tokens": True,
+            "session_id": 123,
+        }
+        _, data = api.generate_request_data(
+            "MiniMax-M1", [{"role": "user", "content": "hi"}], gen_params
+        )
+        self.assertNotIn("repetition_penalty", data)
+        self.assertNotIn("skip_special_tokens", data)
+        self.assertNotIn("session_id", data)
+
+
+class TestEnvExample(unittest.TestCase):
+    """Test that .env.example includes MiniMax configuration."""
+
+    def test_env_example_has_minimax_keys(self):
+        """Test that .env.example contains MINIMAX_API_KEY and MINIMAX_MODEL."""
+        env_path = os.path.join(
+            os.path.dirname(os.path.dirname(__file__)), ".env.example"
+        )
+        with open(env_path) as f:
+            content = f.read()
+        self.assertIn("MINIMAX_API_KEY", content)
+        self.assertIn("MINIMAX_MODEL", content)
+
+
+if __name__ == "__main__":
+    unittest.main()