diff --git a/.github/workflows/test_lemonade.yml b/.github/workflows/test_lemonade.yml index 6073d04..b7bdbbe 100644 --- a/.github/workflows/test_lemonade.yml +++ b/.github/workflows/test_lemonade.yml @@ -45,6 +45,7 @@ jobs: shell: bash -el {0} run: | pylint src/lemonade --rcfile .pylintrc --disable E0401 + pylint examples --rcfile .pylintrc --disable E0401,E0611 --jobs=1 - name: Test HF+CPU server if: runner.os == 'Windows' timeout-minutes: 10 diff --git a/examples/lemonade/demos/chat/chat_hybrid.py b/examples/lemonade/demos/chat/chat_hybrid.py index 8b770ff..d4e3c8f 100644 --- a/examples/lemonade/demos/chat/chat_hybrid.py +++ b/examples/lemonade/demos/chat/chat_hybrid.py @@ -1,6 +1,6 @@ import sys from threading import Thread, Event -from transformers import StoppingCriteria, StoppingCriteriaList +from transformers import StoppingCriteriaList from lemonade.tools.chat import StopOnEvent from lemonade import leap from lemonade.tools.ort_genai.oga import OrtGenaiStreamer diff --git a/examples/lemonade/demos/chat/chat_start.py b/examples/lemonade/demos/chat/chat_start.py index 22724f1..a094c83 100644 --- a/examples/lemonade/demos/chat/chat_start.py +++ b/examples/lemonade/demos/chat/chat_start.py @@ -1,9 +1,9 @@ import sys from threading import Thread, Event -from transformers import StoppingCriteriaList -from lemonade.tools.chat import StopOnEvent from queue import Queue from time import sleep +from transformers import StoppingCriteriaList +from lemonade.tools.chat import StopOnEvent class TextStreamer: @@ -43,6 +43,7 @@ def generate_placeholder( Not needed once we integrate with LEAP. """ + # pylint: disable=line-too-long response = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum." for word in response.split(" "): diff --git a/examples/lemonade/demos/search/search_start.py b/examples/lemonade/demos/search/search_start.py index 705cea7..8249e29 100644 --- a/examples/lemonade/demos/search/search_start.py +++ b/examples/lemonade/demos/search/search_start.py @@ -1,11 +1,10 @@ import sys from threading import Thread, Event +from queue import Queue +from time import sleep from transformers import StoppingCriteriaList from lemonade.tools.chat import StopOnEvent -# These imports are not needed when we add the LLM -from queue import Queue -from time import sleep employee_handbook = """ 1. You will work very hard every day.\n diff --git a/src/lemonade/tools/ort_genai/oga.py b/src/lemonade/tools/ort_genai/oga.py index 818a340..e859788 100644 --- a/src/lemonade/tools/ort_genai/oga.py +++ b/src/lemonade/tools/ort_genai/oga.py @@ -113,6 +113,7 @@ def generate( self, input_ids, max_new_tokens=512, + min_new_tokens=0, do_sample=True, top_k=50, top_p=1.0, @@ -135,6 +136,7 @@ def generate( params.pad_token_id = pad_token_id max_length = len(input_ids) + max_new_tokens + min_length = len(input_ids) + min_new_tokens if use_oga_pre_6_api: params.input_ids = input_ids @@ -147,7 +149,7 @@ def generate( top_p=search_config.get("top_p", top_p), temperature=search_config.get("temperature", temperature), max_length=max_length, - min_length=0, + min_length=min_length, early_stopping=search_config.get("early_stopping", False), length_penalty=search_config.get("length_penalty", 1.0), num_beams=search_config.get("num_beams", 1), @@ -167,7 +169,7 @@ def generate( top_p=top_p, temperature=temperature, max_length=max_length, - min_length=max_length, + min_length=min_length, ) params.try_graph_capture_with_max_batch_size(1) diff --git a/src/lemonade/tools/ort_genai/oga_bench.py b/src/lemonade/tools/ort_genai/oga_bench.py index ba9d8a1..93ae746 100644 --- a/src/lemonade/tools/ort_genai/oga_bench.py +++ b/src/lemonade/tools/ort_genai/oga_bench.py @@ -161,7 +161,11 @@ def run( model.generate(input_ids, max_new_tokens=output_tokens) for _ in tqdm.tqdm(range(iterations), desc="iterations"): - outputs = model.generate(input_ids, max_new_tokens=output_tokens) + outputs = model.generate( + input_ids, + max_new_tokens=output_tokens, + min_new_tokens=output_tokens, + ) token_len = len(outputs[0]) - input_ids_len diff --git a/src/turnkeyml/version.py b/src/turnkeyml/version.py index 3a223dd..4682e61 100644 --- a/src/turnkeyml/version.py +++ b/src/turnkeyml/version.py @@ -1 +1 @@ -__version__ = "5.0.2" +__version__ = "5.0.3"