feat: add proofread provider
Signed-off-by: Jana Peper <[email protected]>
janepie committed Dec 20, 2024
1 parent 00ea14f commit 7a94d7d
Showing 2 changed files with 43 additions and 0 deletions.
40 changes: 40 additions & 0 deletions lib/proofread.py
@@ -0,0 +1,40 @@
# SPDX-FileCopyrightText: 2024 Nextcloud GmbH and Nextcloud contributors
# SPDX-License-Identifier: AGPL-3.0-or-later
"""A chain to proofread a text
"""

from typing import Any
from langchain.prompts import PromptTemplate
from langchain.schema.prompt_template import BasePromptTemplate
from langchain_core.runnables import Runnable


class ProofreadProcessor:
"""
A proofreading chain
"""
system_prompt: str = "You're an AI assistant tasked with proofreading the text given to you by the user."
user_prompt: BasePromptTemplate = PromptTemplate(
input_variables=["text"],
template="""
Detect all grammar and spelling mistakes of the following text in its original language. Output only the list of mistakes in bullet points.
"
{text}
"
Give me the list of all mistakes in the above text in its original language. Do not output the language. Output only the list in bullet points, nothing else, no introductory or explanatory text.
"""
)

runnable: Runnable

def __init__(self, runnable: Runnable):
self.runnable = runnable

def __call__(
self,
inputs: dict[str, Any],
) -> dict[str, Any]:
output = self.runnable.invoke({"user_prompt": self.user_prompt.format_prompt(text=inputs['input']), "system_prompt": self.system_prompt})
return {'output': output}
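
Below is a minimal usage sketch for the new processor (not part of the commit). It assumes any LangChain Runnable that accepts a dict with "user_prompt" and "system_prompt" keys and returns a string, which is the contract ProofreadProcessor relies on above; the fake runnable here is only a stand-in for the chain produced by generate_llm_chain().

# Usage sketch (illustrative only): wire ProofreadProcessor to any Runnable
# that accepts {"user_prompt": ..., "system_prompt": ...} and returns a string.
from langchain_core.runnables import RunnableLambda

from proofread import ProofreadProcessor


def fake_llm(prompts: dict) -> str:
    # Stand-in for a real model call; a production runnable would send
    # prompts["system_prompt"] and prompts["user_prompt"] to an LLM.
    return '- "recieve" should be spelled "receive"'


processor = ProofreadProcessor(RunnableLambda(fake_llm))
result = processor({"input": "I will recieve the package tomorrow."})
print(result["output"])  # a bullet list of detected mistakes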
3 changes: 3 additions & 0 deletions lib/task_processors.py
@@ -16,6 +16,7 @@
from chat import ChatProcessor
from free_prompt import FreePromptProcessor
from headline import HeadlineProcessor
from proofread import ProofreadProcessor
from chatwithtools import ChatWithToolsProcessor
from topics import TopicsProcessor
from summarize import SummarizeProcessor
@@ -132,5 +133,7 @@ def generate_task_processors_for_model(file_name, task_processors):
# chains[model_name + ":core:text2text:reformulation"] = lambda: ReformulateChain(llm_chain=llm_chain(), chunk_size=chunk_size)
task_processors[model_name + ":core:text2text"] = lambda: FreePromptProcessor(generate_llm_chain(file_name))
task_processors[model_name + ":core:text2text:chat"] = lambda: ChatProcessor(generate_chat_chain(file_name))
task_processors[model_name + ":core:text2text:proofread"] = lambda: ProofreadProcessor(generate_llm_chain(file_name))
task_processors[model_name + ":core:text2text:chatwithtools"] = lambda: ChatWithToolsProcessor(generate_chat_chain(file_name))

# chains[model_name + ":core:contextwrite"] = lambda: ContextWriteChain(llm_chain=llm_chain())
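
For context, a sketch of how the newly registered key would be consumed (illustrative only; the model name is made up, and the exact key prefix depends on how model_name is derived from file_name elsewhere in this module):

# Consumption sketch (illustrative only): the registry maps task-type keys to
# zero-argument factories, so the proofread processor is built lazily on lookup.
task_processors = {}
generate_task_processors_for_model("example-model.gguf", task_processors)  # hypothetical file name

factory = task_processors["example-model:core:text2text:proofread"]  # key prefix is an assumption
processor = factory()  # builds ProofreadProcessor(generate_llm_chain("example-model.gguf"))
result = processor({"input": "Their going to the store tomorow."})
print(result["output"])  # bullet list of detected grammar and spelling mistakes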
