Commit 1fecfea

Authored by 3coins, pre-commit-ci[bot], and dlqqq
Model parameters option to pass in model tuning, arbitrary parameters (#430)
* Endpoint args for SM endpoints
* Added model and endpoints kwargs options.
* Added configurable option for model parameters.
* Updated magics, added model_parameters, removed model_kwargs and endpoint_kwargs.
* Fixes %ai error for SM endpoints.
* Fixed docs
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* 430 fixes (#2)
  * log configured model_parameters
  * fix markdown formatting in docs
  * fix single quotes and use preferred traitlets CLI syntax

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: david qiu <[email protected]>
1 parent 9f69469 commit 1fecfea
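
In short, this commit introduces an `AiExtension.model_parameters` configuration option; for example (command taken verbatim from the documentation added in this commit):

```bash
jupyter lab --AiExtension.model_parameters bedrock:ai21.j2-mid-v1='{"model_kwargs":{"maxTokens":200}}'
```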

File tree

10 files changed: +208 -25 lines changed


docs/source/users/index.md

Lines changed: 113 additions & 5 deletions
@@ -855,30 +855,138 @@ The `--response-path` option is a [JSONPath](https://goessner.net/articles/JsonP
 
 ## Configuration
 
-You can specify an allowlist, to only allow only a certain list of providers, or a blocklist, to block some providers.
+You can specify an allowlist, to only allow only a certain list of providers, or
+a blocklist, to block some providers.
 
 ### Blocklisting providers
-This configuration allows for blocking specific providers in the settings panel. This list takes precedence over the allowlist in the next section.
+
+This configuration allows for blocking specific providers in the settings panel.
+This list takes precedence over the allowlist in the next section.
 
 ```
 jupyter lab --AiExtension.blocked_providers=openai
 ```
 
-To block more than one provider in the block-list, repeat the runtime configuration.
+To block more than one provider in the block-list, repeat the runtime
+configuration.
 
 ```
 jupyter lab --AiExtension.blocked_providers=openai --AiExtension.blocked_providers=ai21
 ```
 
 ### Allowlisting providers
-This configuration allows for filtering the list of providers in the settings panel to only an allowlisted set of providers.
+
+This configuration allows for filtering the list of providers in the settings
+panel to only an allowlisted set of providers.
 
 ```
 jupyter lab --AiExtension.allowed_providers=openai
 ```
 
-To allow more than one provider in the allowlist, repeat the runtime configuration.
+To allow more than one provider in the allowlist, repeat the runtime
+configuration.
 
 ```
 jupyter lab --AiExtension.allowed_providers=openai --AiExtension.allowed_providers=ai21
 ```
+
+### Model parameters
+
+This configuration allows specifying arbitrary parameters that are unpacked and
+passed to the provider class. This is useful for passing parameters such as
+model tuning that affect the response generation by the model. This is also an
+appropriate place to pass in custom attributes required by certain
+providers/models.
+
+The accepted value is a dictionary, with top level keys as the model id
+(provider:model_id), and value should be any arbitrary dictionary which is
+unpacked and passed as-is to the provider class.
+
+#### Configuring as a startup option
+
+In this sample, the `bedrock` provider will be created with the value for
+`model_kwargs` when `ai21.j2-mid-v1` model is selected.
+
+```bash
+jupyter lab --AiExtension.model_parameters bedrock:ai21.j2-mid-v1='{"model_kwargs":{"maxTokens":200}}'
+```
+
+Note the usage of single quotes surrounding the dictionary to escape the double
+quotes. This is required in some shells. The above will result in the following
+LLM class to be generated.
+
+```python
+BedrockProvider(model_kwargs={"maxTokens":200}, ...)
+```
+
+Here is another example, where `anthropic` provider will be created with the
+values for `max_tokens` and `temperature`, when `claude-2` model is selected.
+
+
+```bash
+jupyter lab --AiExtension.model_parameters anthropic:claude-2='{"max_tokens":1024,"temperature":0.9}'
+```
+
+The above will result in the following LLM class to be generated.
+
+```python
+AnthropicProvider(max_tokens=1024, temperature=0.9, ...)
+```
+
+To pass multiple sets of model parameters for multiple models in the
+command-line, you can append them as additional arguments to
+`--AiExtension.model_parameters`, as shown below.
+
+```bash
+jupyter lab \
+    --AiExtension.model_parameters bedrock:ai21.j2-mid-v1='{"model_kwargs":{"maxTokens":200}}' \
+    --AiExtension.model_parameters anthropic:claude-2='{"max_tokens":1024,"temperature":0.9}'
+```
+
+However, for more complex configuration, we highly recommend that you specify
+this in a dedicated configuration file. We will describe how to do so in the
+following section.
+
+#### Configuring as a config file
+
+This configuration can also be specified in a config file in json format. The
+file should be named `jupyter_jupyter_ai_config.json` and saved in a path that
+JupyterLab can pick from. You can find this path by running `jupyter --paths`
+command, and picking one of the paths from the `config` section.
+
+Here is an example of running the `jupyter --paths` command.
+
+```bash
+(jupyter-ai-lab4) ➜ jupyter --paths
+config:
+    /opt/anaconda3/envs/jupyter-ai-lab4/etc/jupyter
+    /Users/3coins/.jupyter
+    /Users/3coins/.local/etc/jupyter
+    /usr/3coins/etc/jupyter
+    /etc/jupyter
+data:
+    /opt/anaconda3/envs/jupyter-ai-lab4/share/jupyter
+    /Users/3coins/Library/Jupyter
+    /Users/3coins/.local/share/jupyter
+    /usr/local/share/jupyter
+    /usr/share/jupyter
+runtime:
+    /Users/3coins/Library/Jupyter/runtime
+```
+
+Here is an example for configuring the `bedrock` provider for `ai21.j2-mid-v1`
+model.
+
+```json
+{
+    "AiExtension": {
+        "model_parameters": {
+            "bedrock:ai21.j2-mid-v1": {
+                "model_kwargs": {
+                    "maxTokens": 200
+                }
+            }
+        }
+    }
+}
+```
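
As a quick, standalone sanity check of the dictionary shape documented above (a minimal sketch; the values are the ones from the example config, and the key format is `provider_id:model_id`):

```python
import json

# The value of AiExtension.model_parameters, as parsed from the
# jupyter_jupyter_ai_config.json example shown above.
config = json.loads("""
{
    "AiExtension": {
        "model_parameters": {
            "bedrock:ai21.j2-mid-v1": {
                "model_kwargs": {
                    "maxTokens": 200
                }
            }
        }
    }
}
""")

# Top-level keys are "<provider_id>:<model_id>"; each value is an arbitrary
# dict that is unpacked and passed as-is to the provider class.
params = config["AiExtension"]["model_parameters"]["bedrock:ai21.j2-mid-v1"]
assert params == {"model_kwargs": {"maxTokens": 200}}
```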

packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py

Lines changed: 1 addition & 2 deletions
@@ -1,4 +1,4 @@
-from typing import ClassVar, List, Type
+from typing import ClassVar, List
 
 from jupyter_ai_magics.providers import (
     AuthStrategy,
@@ -12,7 +12,6 @@
     HuggingFaceHubEmbeddings,
     OpenAIEmbeddings,
 )
-from langchain.embeddings.base import Embeddings
 from pydantic import BaseModel, Extra
 
 
packages/jupyter-ai-magics/jupyter_ai_magics/magics.py

Lines changed: 7 additions & 12 deletions
@@ -392,9 +392,11 @@ def handle_error(self, args: ErrorArgs):
 
         prompt = f"Explain the following error:\n\n{last_error}"
         # Set CellArgs based on ErrorArgs
-        cell_args = CellArgs(
-            type="root", model_id=args.model_id, format=args.format, reset=False
-        )
+        values = args.dict()
+        values["type"] = "root"
+        values["reset"] = False
+        cell_args = CellArgs(**values)
+
         return self.run_ai_cell(cell_args, prompt)
 
     def _append_exchange_openai(self, prompt: str, output: str):
@@ -518,16 +520,9 @@ def run_ai_cell(self, args: CellArgs, prompt: str):
             provider_params["request_schema"] = args.request_schema
             provider_params["response_path"] = args.response_path
 
-        # Validate that the request schema is well-formed JSON
-        try:
-            json.loads(args.request_schema)
-        except json.JSONDecodeError as e:
-            raise ValueError(
-                "request-schema must be valid JSON. "
-                f"Error at line {e.lineno}, column {e.colno}: {e.msg}"
-            ) from None
+        model_parameters = json.loads(args.model_parameters)
 
-        provider = Provider(**provider_params)
+        provider = Provider(**provider_params, **model_parameters)
 
         # Apply a prompt template.
         prompt = provider.get_prompt_template(args.format).format(prompt=prompt)
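
In effect, `run_ai_cell` now follows the pattern sketched below (an illustration only, with a hypothetical `Provider` stand-in, not the actual module): the `--model-parameters` value arrives as a JSON string defaulting to `"{}"`, is parsed, and is unpacked into the provider constructor alongside the existing `provider_params`.

```python
import json

def build_provider(Provider, provider_params: dict, model_parameters_json: str = "{}"):
    # json.loads is safe here: the click callback (verify_json_value, added
    # in parsers.py below) has already rejected invalid JSON at parse time.
    model_parameters = json.loads(model_parameters_json)
    # Keyword unpacking merges both dicts; a key appearing in both would
    # raise "TypeError: got multiple values for keyword argument".
    return Provider(**provider_params, **model_parameters)
```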

packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py

Lines changed: 42 additions & 0 deletions
@@ -1,3 +1,4 @@
+import json
 from typing import Literal, Optional, get_args
 
 import click
@@ -32,12 +33,21 @@
     + "does nothing with other providers."
 )
 
+MODEL_PARAMETERS_SHORT_OPTION = "-m"
+MODEL_PARAMETERS_LONG_OPTION = "--model-parameters"
+MODEL_PARAMETERS_HELP = (
+    "A JSON value that specifies extra values that will be passed "
+    "to the model. The accepted value parsed to a dict, unpacked "
+    "and passed as-is to the provider class."
+)
+
 
 class CellArgs(BaseModel):
     type: Literal["root"] = "root"
     model_id: str
     format: FORMAT_CHOICES_TYPE
     reset: bool
+    model_parameters: Optional[str]
     # The following parameters are required only for SageMaker models
     region_name: Optional[str]
     request_schema: Optional[str]
@@ -49,6 +59,7 @@ class ErrorArgs(BaseModel):
     type: Literal["error"] = "error"
     model_id: str
     format: FORMAT_CHOICES_TYPE
+    model_parameters: Optional[str]
     # The following parameters are required only for SageMaker models
     region_name: Optional[str]
     request_schema: Optional[str]
@@ -93,6 +104,19 @@ def get_help(self, ctx):
         click.echo(super().get_help(ctx))
 
 
+def verify_json_value(ctx, param, value):
+    if not value:
+        return value
+    try:
+        json.loads(value)
+    except json.JSONDecodeError as e:
+        raise ValueError(
+            f"{param.get_error_hint(ctx)} must be valid JSON. "
+            f"Error at line {e.lineno}, column {e.colno}: {e.msg}"
+        )
+    return value
+
+
 @click.command()
 @click.argument("model_id")
 @click.option(
@@ -120,13 +144,22 @@ def get_help(self, ctx):
     REQUEST_SCHEMA_LONG_OPTION,
     required=False,
     help=REQUEST_SCHEMA_HELP,
+    callback=verify_json_value,
 )
 @click.option(
     RESPONSE_PATH_SHORT_OPTION,
     RESPONSE_PATH_LONG_OPTION,
     required=False,
     help=RESPONSE_PATH_HELP,
 )
+@click.option(
+    MODEL_PARAMETERS_SHORT_OPTION,
+    MODEL_PARAMETERS_LONG_OPTION,
+    required=False,
+    help=MODEL_PARAMETERS_HELP,
+    callback=verify_json_value,
+    default="{}",
+)
 def cell_magic_parser(**kwargs):
     """
     Invokes a language model identified by MODEL_ID, with the prompt being
@@ -166,13 +199,22 @@ def line_magic_parser():
     REQUEST_SCHEMA_LONG_OPTION,
     required=False,
     help=REQUEST_SCHEMA_HELP,
+    callback=verify_json_value,
 )
 @click.option(
     RESPONSE_PATH_SHORT_OPTION,
     RESPONSE_PATH_LONG_OPTION,
     required=False,
     help=RESPONSE_PATH_HELP,
 )
+@click.option(
+    MODEL_PARAMETERS_SHORT_OPTION,
+    MODEL_PARAMETERS_LONG_OPTION,
+    required=False,
+    help=MODEL_PARAMETERS_HELP,
+    callback=verify_json_value,
+    default="{}",
+)
 def error_subparser(**kwargs):
     """
     Explains the most recent error. Takes the same options (except -r) as
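
For illustration, a cell magic invocation exercising the new `-m`/`--model-parameters` option might look like the following (the prompt text is hypothetical; the model id and JSON payload mirror the documentation above). If the value is not valid JSON, `verify_json_value` raises a `ValueError` naming the offending option.

```
%%ai anthropic:claude-2 -m '{"max_tokens":1024,"temperature":0.9}'
Explain the difference between a list and a tuple.
```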

packages/jupyter-ai-magics/jupyter_ai_magics/providers.py

Lines changed: 12 additions & 1 deletion
@@ -5,7 +5,17 @@
 import io
 import json
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, ClassVar, Coroutine, Dict, List, Literal, Optional, Union
+from typing import (
+    Any,
+    ClassVar,
+    Coroutine,
+    Dict,
+    List,
+    Literal,
+    Mapping,
+    Optional,
+    Union,
+)
 
 from jsonpath_ng import parse
 from langchain.chat_models import (
@@ -621,6 +631,7 @@ def __init__(self, *args, **kwargs):
         content_handler = JsonContentHandler(
             request_schema=request_schema, response_path=response_path
         )
+
         super().__init__(*args, **kwargs, content_handler=content_handler)
 
     async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:

packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py

Lines changed: 2 additions & 1 deletion
@@ -36,7 +36,8 @@ def __init__(self, retriever, *args, **kwargs):
     def create_llm_chain(
         self, provider: Type[BaseProvider], provider_params: Dict[str, str]
     ):
-        self.llm = provider(**provider_params)
+        model_parameters = self.get_model_parameters(provider, provider_params)
+        self.llm = provider(**provider_params, **model_parameters)
         memory = ConversationBufferWindowMemory(
             memory_key="chat_history", return_messages=True, k=2
         )

packages/jupyter-ai/jupyter_ai/chat_handlers/base.py

Lines changed: 10 additions & 1 deletion
@@ -3,7 +3,7 @@
 import traceback
 
 # necessary to prevent circular import
-from typing import TYPE_CHECKING, Dict, Optional, Type
+from typing import TYPE_CHECKING, Any, Dict, Optional, Type
 from uuid import uuid4
 
 from jupyter_ai.config_manager import ConfigManager, Logger
@@ -23,10 +23,12 @@ def __init__(
         log: Logger,
         config_manager: ConfigManager,
         root_chat_handlers: Dict[str, "RootChatHandler"],
+        model_parameters: Dict[str, Dict],
     ):
         self.log = log
         self.config_manager = config_manager
         self._root_chat_handlers = root_chat_handlers
+        self.model_parameters = model_parameters
         self.parser = argparse.ArgumentParser()
         self.llm = None
         self.llm_params = None
@@ -122,6 +124,13 @@ def get_llm_chain(self):
         self.llm_params = lm_provider_params
         return self.llm_chain
 
+    def get_model_parameters(
+        self, provider: Type[BaseProvider], provider_params: Dict[str, str]
+    ):
+        return self.model_parameters.get(
+            f"{provider.id}:{provider_params['model_id']}", {}
+        )
+
     def create_llm_chain(
         self, provider: Type[BaseProvider], provider_params: Dict[str, str]
     ):
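
The lookup that `get_model_parameters` performs can be pictured with a minimal standalone sketch (illustrative values, not the actual module): the configured dict is keyed by `"<provider_id>:<model_id>"`, and models without an override fall back to an empty dict.

```python
model_parameters = {
    "bedrock:ai21.j2-mid-v1": {"model_kwargs": {"maxTokens": 200}},
}

def get_model_parameters(provider_id: str, model_id: str) -> dict:
    # Same key format the handler builds: f"{provider.id}:{provider_params['model_id']}"
    return model_parameters.get(f"{provider_id}:{model_id}", {})

assert get_model_parameters("bedrock", "ai21.j2-mid-v1") == {"model_kwargs": {"maxTokens": 200}}
assert get_model_parameters("anthropic", "claude-2") == {}  # no override configured
```

The handler subclasses changed below (`ask.py`, `default.py`, `generate.py`) all call this helper and unpack the result into the provider constructor.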

packages/jupyter-ai/jupyter_ai/chat_handlers/default.py

Lines changed: 2 additions & 1 deletion
@@ -40,7 +40,8 @@ def __init__(self, chat_history: List[ChatMessage], *args, **kwargs):
     def create_llm_chain(
         self, provider: Type[BaseProvider], provider_params: Dict[str, str]
     ):
-        llm = provider(**provider_params)
+        model_parameters = self.get_model_parameters(provider, provider_params)
+        llm = provider(**provider_params, **model_parameters)
 
         if llm.is_chat_provider:
             prompt_template = ChatPromptTemplate.from_messages(

packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py

Lines changed: 3 additions & 1 deletion
@@ -226,7 +226,9 @@ def __init__(self, root_dir: str, *args, **kwargs):
     def create_llm_chain(
         self, provider: Type[BaseProvider], provider_params: Dict[str, str]
     ):
-        llm = provider(**provider_params)
+        model_parameters = self.get_model_parameters(provider, provider_params)
+        llm = provider(**provider_params, **model_parameters)
+
         self.llm = llm
         return llm
 
