Commit ff1ebde

Merge pull request #42 from nextcloud/feat/add-new-providers
Feat: add new providers
2 parents 3caabc4 + 0886de3 commit ff1ebde

File tree: 4 files changed, +94 −1 lines changed

lib/change_tone.py
lib/main.py
lib/proofread.py
lib/task_processors.py

lib/change_tone.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+# SPDX-FileCopyrightText: 2024 Nextcloud GmbH and Nextcloud contributors
+# SPDX-License-Identifier: AGPL-3.0-or-later
+"""A chain that changes the tone of a text"""
+
+from typing import Any
+
+from langchain.prompts import PromptTemplate
+from langchain.schema.prompt_template import BasePromptTemplate
+from langchain_core.runnables import Runnable
+
+
+class ChangeToneProcessor:
+    """
+    A tone-changing chain
+    """
+
+    runnable: Runnable
+
+    system_prompt: str = "You're an AI assistant tasked with rewriting the text given to you by the user in another tone."
+    user_prompt: BasePromptTemplate = PromptTemplate(
+        input_variables=["text", "tone"],
+        template="""Reformulate the following text in a "{tone}" tone in its original language without mentioning the language. Output only the reformulation, nothing else, no introductory sentence. Here is the text:
+
+"
+{text}
+"
+
+Output only the reformulated text, nothing else. Do not add an introductory sentence.
+""",
+    )
+
+    def __init__(self, runnable: Runnable):
+        self.runnable = runnable
+
+    def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]:
+        output = self.runnable.invoke({
+            "user_prompt": self.user_prompt.format_prompt(text=inputs["input"], tone=inputs["tone"]),
+            "system_prompt": self.system_prompt,
+        })
+        return {"output": output}
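
A minimal usage sketch (not part of the commit), assuming a stub RunnableLambda in place of the real llama.cpp-backed chain, to show the data flow through ChangeToneProcessor:

# Hypothetical usage sketch: the stub echoes the formatted user prompt
# so the template substitution is visible without loading a model.
from langchain_core.runnables import RunnableLambda

from change_tone import ChangeToneProcessor

stub = RunnableLambda(lambda prompts: prompts["user_prompt"].to_string())

processor = ChangeToneProcessor(stub)
result = processor({"input": "See you at the meeting.", "tone": "more formal"})
print(result["output"])  # the filled-in prompt the real chain would receive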

lib/main.py

Lines changed: 10 additions & 1 deletion
@@ -15,7 +15,7 @@
 from fastapi import FastAPI
 from nc_py_api import AsyncNextcloudApp, NextcloudApp, NextcloudException
 from nc_py_api.ex_app import LogLvl, persistent_storage, run_app, set_handlers
-from nc_py_api.ex_app.providers.task_processing import TaskProcessingProvider
+from nc_py_api.ex_app.providers.task_processing import TaskProcessingProvider, ShapeEnumValue

 models_to_fetch = {
     "https://huggingface.co/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/resolve/4f0c246f125fc7594238ebe7beb1435a8335f519/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf": { "save_path": os.path.join(persistent_storage(), "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf") },
@@ -123,6 +123,15 @@ async def enabled_handler(enabled: bool, nc: AsyncNextcloudApp) -> str:
                 name="Local Large language Model: " + model,
                 task_type=task,
                 expected_runtime=30,
+                input_shape_enum_values={
+                    "tone": [
+                        ShapeEnumValue(name="Friendlier", value="friendlier"),
+                        ShapeEnumValue(name="More formal", value="more formal"),
+                        ShapeEnumValue(name="Funnier", value="funnier"),
+                        ShapeEnumValue(name="More casual", value="more casual"),
+                        ShapeEnumValue(name="More urgent", value="more urgent"),
+                    ],
+                } if task == "core:text2text:changetone" else {}
             )
             await nc.providers.task_processing.register(provider)
             print(f"Registered {task_processor_name}", flush=True)

lib/proofread.py

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+# SPDX-FileCopyrightText: 2024 Nextcloud GmbH and Nextcloud contributors
+# SPDX-License-Identifier: AGPL-3.0-or-later
+"""A chain to proofread a text"""
+
+from typing import Any
+from langchain.prompts import PromptTemplate
+from langchain.schema.prompt_template import BasePromptTemplate
+from langchain_core.runnables import Runnable
+
+
+class ProofreadProcessor:
+    """
+    A proofreading chain
+    """
+
+    system_prompt: str = "You're an AI assistant tasked with proofreading the text given to you by the user."
+    user_prompt: BasePromptTemplate = PromptTemplate(
+        input_variables=["text"],
+        template="""
+Detect all grammar and spelling mistakes in the following text in its original language. Output only the list of mistakes in bullet points.
+
+"
+{text}
+"
+
+Give me the list of all mistakes in the above text in its original language. Do not output the language. Output only the list in bullet points, nothing else, no introductory or explanatory text.
+""",
+    )
+
+    runnable: Runnable
+
+    def __init__(self, runnable: Runnable):
+        self.runnable = runnable
+
+    def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]:
+        output = self.runnable.invoke({
+            "user_prompt": self.user_prompt.format_prompt(text=inputs["input"]),
+            "system_prompt": self.system_prompt,
+        })
+        return {"output": output}
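
To see exactly what this chain sends to the model, the template can be rendered without invoking anything; format_prompt(...).to_string() is standard LangChain API:

# Hypothetical inspection sketch: fill the proofreading template with a
# sample text and print it; no model is loaded or invoked.
from proofread import ProofreadProcessor

prompt = ProofreadProcessor.user_prompt.format_prompt(
    text="Their going to the libary tomorow."
).to_string()
print(prompt)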

lib/task_processors.py

Lines changed: 5 additions & 0 deletions
@@ -16,6 +16,8 @@
 from chat import ChatProcessor
 from free_prompt import FreePromptProcessor
 from headline import HeadlineProcessor
+from proofread import ProofreadProcessor
+from change_tone import ChangeToneProcessor
 from chatwithtools import ChatWithToolsProcessor
 from topics import TopicsProcessor
 from summarize import SummarizeProcessor
@@ -132,5 +134,8 @@ def generate_task_processors_for_model(file_name, task_processors):
     # chains[model_name + ":core:text2text:reformulation"] = lambda: ReformulateChain(llm_chain=llm_chain(), chunk_size=chunk_size)
     task_processors[model_name + ":core:text2text"] = lambda: FreePromptProcessor(generate_llm_chain(file_name))
     task_processors[model_name + ":core:text2text:chat"] = lambda: ChatProcessor(generate_chat_chain(file_name))
+    task_processors[model_name + ":core:text2text:proofread"] = lambda: ProofreadProcessor(generate_llm_chain(file_name))
+    task_processors[model_name + ":core:text2text:changetone"] = lambda: ChangeToneProcessor(generate_llm_chain(file_name))
     task_processors[model_name + ":core:text2text:chatwithtools"] = lambda: ChatWithToolsProcessor(generate_chat_chain(file_name))
+
     # chains[model_name + ":core:contextwrite"] = lambda: ContextWriteChain(llm_chain=llm_chain())
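
Each registry entry is a zero-argument lambda, so a chain (and its model) is only built when a task of that type is first dispatched. A hypothetical consumer sketch (not in the commit; the model file name is illustrative and the key is looked up rather than assuming how model_name is derived):

# Hypothetical dispatch sketch: calling the stored lambda builds the
# ProofreadProcessor (and loads the model) only at this point.
task_processors = {}
generate_task_processors_for_model("Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", task_processors)

key = next(k for k in task_processors if k.endswith(":core:text2text:proofread"))
processor = task_processors[key]()  # lambda invoked: chain built now
result = processor({"input": "Teh quick brown fox jump over the lazy dog."})
print(result["output"])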
