diff --git a/server/api.py b/server/api.py index 9ded9a9..22fec2f 100644 --- a/server/api.py +++ b/server/api.py @@ -66,7 +66,7 @@ async def upload_file(repo_name: str, file: UploadFile = File(...)): raise HTTPException(status_code=400, detail=f'directory is not a git repository: {repo_path}') file_path = os.path.join(repo_path, file.filename) - + async with aiofiles.open(file_path, 'wb') as f: content = await file.read() await f.write(content) @@ -74,10 +74,10 @@ async def upload_file(repo_name: str, file: UploadFile = File(...)): repo = Repo(repo_path) repo.git.add([file_path]) repo.index.commit('Add file through API') - msg = {'filename': file.filename, 'message': 'file uploaded and committed successfully'} - - return msg - + return { + 'filename': file.filename, + 'message': 'file uploaded and committed successfully', + } except Exception as err: raise HTTPException(status_code=400, detail=str(err)) @@ -89,17 +89,16 @@ async def read_file_by_name(repo_name: str, prompt_name: str, raw: Optional[bool raise HTTPException(status_code=400, detail=f'directory is not a git repository: {repo_path}') file_path = os.path.join(repo_path, f'{prompt_name}.yml') - - if os.path.exists(file_path): - data = parse_yaml(file_path) - - if raw: - return {'prompt': data.get('prompt')} - else: - return FileResponse(file_path, media_type='application/x-yaml') - - else: + + if not os.path.exists(file_path): raise HTTPException(status_code=404, detail=f'File not found: {file_path}') + data = parse_yaml(file_path) + + return ( + {'prompt': data.get('prompt')} + if raw + else FileResponse(file_path, media_type='application/x-yaml') + ) @app.get('/{repo_name}/_uuid/{prompt_uuid}', responses={200: {'content': {'application/x-yaml': {}}}}) diff --git a/tools/contentctl.py b/tools/contentctl.py index 603a03f..b0666a2 100644 --- a/tools/contentctl.py +++ b/tools/contentctl.py @@ -110,28 +110,27 @@ def ask_for_input(field_name, field_type, is_required, default_value=None): def create_prompt(): # create the YAML map - map_info = {} + map_info = {'title': ask_for_input('title', 'str', True)} - map_info['title'] = ask_for_input('title', 'str', True) map_info['uuid'] = str(uuid.uuid4()) map_info['description'] = ask_for_input('description', 'str', True) map_info['category'] = ask_for_input('category', 'str', True) map_info['provider'] = ask_for_input('provider', 'str', False) map_info['model'] = ask_for_input('model', 'str', False) - + # model settings with default values model_settings_fields = ['temperature', 'top_k', 'top_p', 'max_tokens', 'stream', 'presence_penalty', 'frequency_penalty'] model_settings_types = ['float', 'int', 'float', 'int', 'bool', 'float', 'float'] model_settings_defaults = [0.8, None, 1, None, False, 0.0, 0.0] - + model_settings = {field: ask_for_input(field, ftype, False, default) for field, ftype, default in zip(model_settings_fields, model_settings_types, model_settings_defaults)} - + # only add model_settings to map if it's not empty if any(model_settings.values()): map_info['model_settings'] = model_settings map_info['prompt'] = ask_for_input('prompt', 'str', True) - + # Sequence fields seq_fields = ['references', 'associations', 'packs', 'tags', 'input_variables'] for field in seq_fields: @@ -216,7 +215,7 @@ def display_stats(): action='store', help='create new prompt' ) - + parser.add_argument( '-s', '--stats', action='store', @@ -233,7 +232,7 @@ def display_stats(): if args.init: if not args.config: - rprint(f'[bold red](error)[/bold red] config file required for initialization') + rprint('[bold red](error)[/bold red] config file required for initialization') sys.exit(1) config = Config(args.config) @@ -254,7 +253,7 @@ def display_stats(): collect_stats_from_dir(args.stats) display_stats() - + if args.langchain: if not os.path.exists(args.langchain): rprint(f'[bold red](error)[/bold red] template does not exist: {args.langchain}') @@ -264,7 +263,7 @@ def display_stats(): if original is None or langchain_template is None: rprint(f'[bold red](error)[/bold red] failed to convert prompt: {args.langchain}') sys.exit(1) - + rprint(f'[bold green](status)[/bold green] successfully converted template: {args.langchain}') - rprint(f'[bold orange3]LangChain PromptTemplate[/bold orange3]') + rprint('[bold orange3]LangChain PromptTemplate[/bold orange3]') print(json.dumps(langchain_template.dict(), indent=2)) \ No newline at end of file diff --git a/tools/load_from_github.py b/tools/load_from_github.py index 5ddee37..1bc6141 100644 --- a/tools/load_from_github.py +++ b/tools/load_from_github.py @@ -25,18 +25,22 @@ def get_template(self, full_repo_name, full_prompt_name) -> str: print(f'(error) failed to parse repo name - exception: {err}') print('name should be in the format of username/repo') return prompt_data - + url = f'{self.base_url}/{repo_user}/{repo_name}/main/prompts/{full_prompt_name}.yml' - full_prompt_name = full_prompt_name + '.yml' if not full_prompt_name.endswith('.yml') else full_prompt_name + full_prompt_name = ( + f'{full_prompt_name}.yml' + if not full_prompt_name.endswith('.yml') + else full_prompt_name + ) print(f'(status) retrieving template: {url}') - + try: response = requests.get(url) if response.status_code != 200: print(f'(error) error retrieving template - non 200 status code: {response.status_code}') return prompt_data - + prompt_data = yaml.safe_load(response.text) except Exception as err: