Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sourcery refactored main branch #1

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 14 additions & 15 deletions server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,18 +66,18 @@ async def upload_file(repo_name: str, file: UploadFile = File(...)):
raise HTTPException(status_code=400, detail=f'directory is not a git repository: {repo_path}')

file_path = os.path.join(repo_path, file.filename)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function upload_file refactored with the following changes:

async with aiofiles.open(file_path, 'wb') as f:
content = await file.read()
await f.write(content)

repo = Repo(repo_path)
repo.git.add([file_path])
repo.index.commit('Add file through API')
msg = {'filename': file.filename, 'message': 'file uploaded and committed successfully'}

return msg

return {
'filename': file.filename,
'message': 'file uploaded and committed successfully',
}
except Exception as err:
raise HTTPException(status_code=400, detail=str(err))

Expand All @@ -89,17 +89,16 @@ async def read_file_by_name(repo_name: str, prompt_name: str, raw: Optional[bool
raise HTTPException(status_code=400, detail=f'directory is not a git repository: {repo_path}')

file_path = os.path.join(repo_path, f'{prompt_name}.yml')

if os.path.exists(file_path):
data = parse_yaml(file_path)

if raw:
return {'prompt': data.get('prompt')}
else:
return FileResponse(file_path, media_type='application/x-yaml')

else:

if not os.path.exists(file_path):
raise HTTPException(status_code=404, detail=f'File not found: {file_path}')
data = parse_yaml(file_path)

return (
{'prompt': data.get('prompt')}
if raw
else FileResponse(file_path, media_type='application/x-yaml')
)
Comment on lines -92 to +101
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function read_file_by_name refactored with the following changes:



@app.get('/{repo_name}/_uuid/{prompt_uuid}', responses={200: {'content': {'application/x-yaml': {}}}})
Expand Down
21 changes: 10 additions & 11 deletions tools/contentctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,28 +110,27 @@ def ask_for_input(field_name, field_type, is_required, default_value=None):

def create_prompt():
# create the YAML map
map_info = {}
map_info = {'title': ask_for_input('title', 'str', True)}

map_info['title'] = ask_for_input('title', 'str', True)
map_info['uuid'] = str(uuid.uuid4())
map_info['description'] = ask_for_input('description', 'str', True)
map_info['category'] = ask_for_input('category', 'str', True)
map_info['provider'] = ask_for_input('provider', 'str', False)
map_info['model'] = ask_for_input('model', 'str', False)

# model settings with default values
model_settings_fields = ['temperature', 'top_k', 'top_p', 'max_tokens', 'stream', 'presence_penalty', 'frequency_penalty']
model_settings_types = ['float', 'int', 'float', 'int', 'bool', 'float', 'float']
model_settings_defaults = [0.8, None, 1, None, False, 0.0, 0.0]

model_settings = {field: ask_for_input(field, ftype, False, default) for field, ftype, default in zip(model_settings_fields, model_settings_types, model_settings_defaults)}

# only add model_settings to map if it's not empty
if any(model_settings.values()):
map_info['model_settings'] = model_settings

map_info['prompt'] = ask_for_input('prompt', 'str', True)

Comment on lines -113 to +133
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function create_prompt refactored with the following changes:

# Sequence fields
seq_fields = ['references', 'associations', 'packs', 'tags', 'input_variables']
for field in seq_fields:
Expand Down Expand Up @@ -216,7 +215,7 @@ def display_stats():
action='store',
help='create new prompt'
)

Comment on lines -219 to +218
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lines 219-269 refactored with the following changes:

parser.add_argument(
'-s', '--stats',
action='store',
Expand All @@ -233,7 +232,7 @@ def display_stats():

if args.init:
if not args.config:
rprint(f'[bold red](error)[/bold red] config file required for initialization')
rprint('[bold red](error)[/bold red] config file required for initialization')
sys.exit(1)

config = Config(args.config)
Expand All @@ -254,7 +253,7 @@ def display_stats():

collect_stats_from_dir(args.stats)
display_stats()

if args.langchain:
if not os.path.exists(args.langchain):
rprint(f'[bold red](error)[/bold red] template does not exist: {args.langchain}')
Expand All @@ -264,7 +263,7 @@ def display_stats():
if original is None or langchain_template is None:
rprint(f'[bold red](error)[/bold red] failed to convert prompt: {args.langchain}')
sys.exit(1)

rprint(f'[bold green](status)[/bold green] successfully converted template: {args.langchain}')
rprint(f'[bold orange3]LangChain PromptTemplate[/bold orange3]')
rprint('[bold orange3]LangChain PromptTemplate[/bold orange3]')
print(json.dumps(langchain_template.dict(), indent=2))
12 changes: 8 additions & 4 deletions tools/load_from_github.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,22 @@ def get_template(self, full_repo_name, full_prompt_name) -> str:
print(f'(error) failed to parse repo name - exception: {err}')
print('name should be in the format of username/repo')
return prompt_data

url = f'{self.base_url}/{repo_user}/{repo_name}/main/prompts/{full_prompt_name}.yml'
full_prompt_name = full_prompt_name + '.yml' if not full_prompt_name.endswith('.yml') else full_prompt_name
full_prompt_name = (
f'{full_prompt_name}.yml'
if not full_prompt_name.endswith('.yml')
else full_prompt_name
)

print(f'(status) retrieving template: {url}')

try:
response = requests.get(url)
if response.status_code != 200:
print(f'(error) error retrieving template - non 200 status code: {response.status_code}')
return prompt_data

Comment on lines -28 to +43
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function PromptLoader.get_template refactored with the following changes:

prompt_data = yaml.safe_load(response.text)

except Exception as err:
Expand Down