-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sourcery refactored main branch #1
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -66,18 +66,18 @@ async def upload_file(repo_name: str, file: UploadFile = File(...)): | |
raise HTTPException(status_code=400, detail=f'directory is not a git repository: {repo_path}') | ||
|
||
file_path = os.path.join(repo_path, file.filename) | ||
|
||
async with aiofiles.open(file_path, 'wb') as f: | ||
content = await file.read() | ||
await f.write(content) | ||
|
||
repo = Repo(repo_path) | ||
repo.git.add([file_path]) | ||
repo.index.commit('Add file through API') | ||
msg = {'filename': file.filename, 'message': 'file uploaded and committed successfully'} | ||
|
||
return msg | ||
|
||
return { | ||
'filename': file.filename, | ||
'message': 'file uploaded and committed successfully', | ||
} | ||
except Exception as err: | ||
raise HTTPException(status_code=400, detail=str(err)) | ||
|
||
|
@@ -89,17 +89,16 @@ async def read_file_by_name(repo_name: str, prompt_name: str, raw: Optional[bool | |
raise HTTPException(status_code=400, detail=f'directory is not a git repository: {repo_path}') | ||
|
||
file_path = os.path.join(repo_path, f'{prompt_name}.yml') | ||
|
||
if os.path.exists(file_path): | ||
data = parse_yaml(file_path) | ||
|
||
if raw: | ||
return {'prompt': data.get('prompt')} | ||
else: | ||
return FileResponse(file_path, media_type='application/x-yaml') | ||
|
||
else: | ||
|
||
if not os.path.exists(file_path): | ||
raise HTTPException(status_code=404, detail=f'File not found: {file_path}') | ||
data = parse_yaml(file_path) | ||
|
||
return ( | ||
{'prompt': data.get('prompt')} | ||
if raw | ||
else FileResponse(file_path, media_type='application/x-yaml') | ||
) | ||
Comment on lines
-92
to
+101
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
|
||
@app.get('/{repo_name}/_uuid/{prompt_uuid}', responses={200: {'content': {'application/x-yaml': {}}}}) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -110,28 +110,27 @@ def ask_for_input(field_name, field_type, is_required, default_value=None): | |
|
||
def create_prompt(): | ||
# create the YAML map | ||
map_info = {} | ||
map_info = {'title': ask_for_input('title', 'str', True)} | ||
|
||
map_info['title'] = ask_for_input('title', 'str', True) | ||
map_info['uuid'] = str(uuid.uuid4()) | ||
map_info['description'] = ask_for_input('description', 'str', True) | ||
map_info['category'] = ask_for_input('category', 'str', True) | ||
map_info['provider'] = ask_for_input('provider', 'str', False) | ||
map_info['model'] = ask_for_input('model', 'str', False) | ||
|
||
# model settings with default values | ||
model_settings_fields = ['temperature', 'top_k', 'top_p', 'max_tokens', 'stream', 'presence_penalty', 'frequency_penalty'] | ||
model_settings_types = ['float', 'int', 'float', 'int', 'bool', 'float', 'float'] | ||
model_settings_defaults = [0.8, None, 1, None, False, 0.0, 0.0] | ||
|
||
model_settings = {field: ask_for_input(field, ftype, False, default) for field, ftype, default in zip(model_settings_fields, model_settings_types, model_settings_defaults)} | ||
|
||
# only add model_settings to map if it's not empty | ||
if any(model_settings.values()): | ||
map_info['model_settings'] = model_settings | ||
|
||
map_info['prompt'] = ask_for_input('prompt', 'str', True) | ||
|
||
Comment on lines
-113
to
+133
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
# Sequence fields | ||
seq_fields = ['references', 'associations', 'packs', 'tags', 'input_variables'] | ||
for field in seq_fields: | ||
|
@@ -216,7 +215,7 @@ def display_stats(): | |
action='store', | ||
help='create new prompt' | ||
) | ||
|
||
Comment on lines
-219
to
+218
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lines
|
||
parser.add_argument( | ||
'-s', '--stats', | ||
action='store', | ||
|
@@ -233,7 +232,7 @@ def display_stats(): | |
|
||
if args.init: | ||
if not args.config: | ||
rprint(f'[bold red](error)[/bold red] config file required for initialization') | ||
rprint('[bold red](error)[/bold red] config file required for initialization') | ||
sys.exit(1) | ||
|
||
config = Config(args.config) | ||
|
@@ -254,7 +253,7 @@ def display_stats(): | |
|
||
collect_stats_from_dir(args.stats) | ||
display_stats() | ||
|
||
if args.langchain: | ||
if not os.path.exists(args.langchain): | ||
rprint(f'[bold red](error)[/bold red] template does not exist: {args.langchain}') | ||
|
@@ -264,7 +263,7 @@ def display_stats(): | |
if original is None or langchain_template is None: | ||
rprint(f'[bold red](error)[/bold red] failed to convert prompt: {args.langchain}') | ||
sys.exit(1) | ||
|
||
rprint(f'[bold green](status)[/bold green] successfully converted template: {args.langchain}') | ||
rprint(f'[bold orange3]LangChain PromptTemplate[/bold orange3]') | ||
rprint('[bold orange3]LangChain PromptTemplate[/bold orange3]') | ||
print(json.dumps(langchain_template.dict(), indent=2)) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,18 +25,22 @@ def get_template(self, full_repo_name, full_prompt_name) -> str: | |
print(f'(error) failed to parse repo name - exception: {err}') | ||
print('name should be in the format of username/repo') | ||
return prompt_data | ||
|
||
url = f'{self.base_url}/{repo_user}/{repo_name}/main/prompts/{full_prompt_name}.yml' | ||
full_prompt_name = full_prompt_name + '.yml' if not full_prompt_name.endswith('.yml') else full_prompt_name | ||
full_prompt_name = ( | ||
f'{full_prompt_name}.yml' | ||
if not full_prompt_name.endswith('.yml') | ||
else full_prompt_name | ||
) | ||
|
||
print(f'(status) retrieving template: {url}') | ||
|
||
try: | ||
response = requests.get(url) | ||
if response.status_code != 200: | ||
print(f'(error) error retrieving template - non 200 status code: {response.status_code}') | ||
return prompt_data | ||
|
||
Comment on lines
-28
to
+43
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
prompt_data = yaml.safe_load(response.text) | ||
|
||
except Exception as err: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function
upload_file
refactored with the following changes:inline-immediately-returned-variable
)