diff --git a/pyannotate_tools/annotations/__main__.py b/pyannotate_tools/annotations/__main__.py index 33bde4f..32223bf 100644 --- a/pyannotate_tools/annotations/__main__.py +++ b/pyannotate_tools/annotations/__main__.py @@ -12,10 +12,13 @@ from pyannotate_tools.annotations.main import generate_annotations_json_string, unify_type_comments from pyannotate_tools.fixes.fix_annotate_json import FixAnnotateJson +from pyannotate_tools.fixes.fix_annotate_command import FixAnnotateCommand parser = argparse.ArgumentParser() parser.add_argument('--type-info', default='type_info.json', metavar="FILE", help="JSON input file (default type_info.json)") +parser.add_argument('--command', '-c', metavar="COMMAND", + help="Command to call to generate JSON info for a call site") parser.add_argument('--uses-signature', action='store_true', help="JSON input uses a signature format") parser.add_argument('-p', '--print-function', action='store_true', @@ -111,24 +114,30 @@ def main(args_override=None): if args.auto_any: fixers = ['pyannotate_tools.fixes.fix_annotate'] else: - # Produce nice error message if type_info.json not found. - try: - with open(args.type_info) as f: - contents = f.read() - except IOError as err: - sys.exit("Can't open type info file: %s" % err) - - # Run pass 2 with output into a variable. - if args.uses_signature: - data = json.loads(contents) # type: List[Any] - else: - data = generate_annotations_json_string( - args.type_info, - only_simple=args.only_simple) + fixers = [] + if args.type_info: + # Produce nice error message if type_info.json not found. + try: + with open(args.type_info) as f: + contents = f.read() + except IOError as err: + sys.exit("Can't open type info file: %s" % err) + + # Run pass 2 with output into a variable. + if args.uses_signature: + data = json.loads(contents) # type: List[Any] + else: + data = generate_annotations_json_string( + args.type_info, + only_simple=args.only_simple) + + # Run pass 3 with input from that variable. + FixAnnotateJson.init_stub_json_from_data(data, args.files[0]) + fixers.append('pyannotate_tools.fixes.fix_annotate_json') - # Run pass 3 with input from that variable. - FixAnnotateJson.init_stub_json_from_data(data, args.files[0]) - fixers = ['pyannotate_tools.fixes.fix_annotate_json'] + if args.command: + fixers.append('pyannotate_tools.fixes.fix_annotate_command') + FixAnnotateCommand.set_command(args.command) flags = {'print_function': args.print_function, 'annotation_style': annotation_style} diff --git a/pyannotate_tools/fixes/fix_annotate.py b/pyannotate_tools/fixes/fix_annotate.py index b3eeac8..a141b24 100644 --- a/pyannotate_tools/fixes/fix_annotate.py +++ b/pyannotate_tools/fixes/fix_annotate.py @@ -42,7 +42,7 @@ def foo(self, bar : Any, baz : int = 12) -> Any: from lib2to3.pytree import Leaf, Node -class FixAnnotate(BaseFix): +class BaseFixAnnotate(BaseFix): # This fixer is compatible with the bottom matcher. BM_compatible = True @@ -59,8 +59,8 @@ class FixAnnotate(BaseFix): counter = None if not _maxfixes else int(_maxfixes) def transform(self, node, results): - if FixAnnotate.counter is not None: - if FixAnnotate.counter <= 0: + if BaseFixAnnotate.counter is not None: + if BaseFixAnnotate.counter <= 0: return # Check if there's already a long-form annotation for some argument. @@ -181,8 +181,8 @@ def transform(self, node, results): self.add_py2_annot(argtypes, restype, node, results) # Common to py2 and py3 style annotations: - if FixAnnotate.counter is not None: - FixAnnotate.counter -= 1 + if BaseFixAnnotate.counter is not None: + BaseFixAnnotate.counter -= 1 # Also add 'from typing import Any' at the top if needed. self.patch_imports(argtypes + [restype], node) @@ -348,67 +348,7 @@ def patch_imports(self, types, node): break def make_annotation(self, node, results): - name = results['name'] - assert isinstance(name, Leaf), repr(name) - assert name.type == token.NAME, repr(name) - decorators = self.get_decorators(node) - is_method = self.is_method(node) - if name.value == '__init__' or not self.has_return_exprs(node): - restype = 'None' - else: - restype = 'Any' - args = results.get('args') - argtypes = [] - if isinstance(args, Node): - children = args.children - elif isinstance(args, Leaf): - children = [args] - else: - children = [] - # Interpret children according to the following grammar: - # (('*'|'**')? NAME ['=' expr] ','?)* - stars = inferred_type = '' - in_default = False - at_start = True - for child in children: - if isinstance(child, Leaf): - if child.value in ('*', '**'): - stars += child.value - elif child.type == token.NAME and not in_default: - if not is_method or not at_start or 'staticmethod' in decorators: - inferred_type = 'Any' - else: - # Always skip the first argument if it's named 'self'. - # Always skip the first argument of a class method. - if child.value == 'self' or 'classmethod' in decorators: - pass - else: - inferred_type = 'Any' - elif child.value == '=': - in_default = True - elif in_default and child.value != ',': - if child.type == token.NUMBER: - if re.match(r'\d+[lL]?$', child.value): - inferred_type = 'int' - else: - inferred_type = 'float' # TODO: complex? - elif child.type == token.STRING: - if child.value.startswith(('u', 'U')): - inferred_type = 'unicode' - else: - inferred_type = 'str' - elif child.type == token.NAME and child.value in ('True', 'False'): - inferred_type = 'bool' - elif child.value == ',': - if inferred_type: - argtypes.append(stars + inferred_type) - # Reset - stars = inferred_type = '' - in_default = False - at_start = False - if inferred_type: - argtypes.append(stars + inferred_type) - return argtypes, restype + raise NotImplementedError # The parse tree has a different shape when there is a single # decorator vs. when there are multiple decorators. @@ -479,3 +419,69 @@ def is_generator(self, node): if self.is_generator(child): return True return False + + +class FixAnnotate(BaseFixAnnotate): + + def make_annotation(self, node, results): + name = results['name'] + assert isinstance(name, Leaf), repr(name) + assert name.type == token.NAME, repr(name) + decorators = self.get_decorators(node) + is_method = self.is_method(node) + if name.value == '__init__' or not self.has_return_exprs(node): + restype = 'None' + else: + restype = 'Any' + args = results.get('args') + argtypes = [] + if isinstance(args, Node): + children = args.children + elif isinstance(args, Leaf): + children = [args] + else: + children = [] + # Interpret children according to the following grammar: + # (('*'|'**')? NAME ['=' expr] ','?)* + stars = inferred_type = '' + in_default = False + at_start = True + for child in children: + if isinstance(child, Leaf): + if child.value in ('*', '**'): + stars += child.value + elif child.type == token.NAME and not in_default: + if not is_method or not at_start or 'staticmethod' in decorators: + inferred_type = 'Any' + else: + # Always skip the first argument if it's named 'self'. + # Always skip the first argument of a class method. + if child.value == 'self' or 'classmethod' in decorators: + pass + else: + inferred_type = 'Any' + elif child.value == '=': + in_default = True + elif in_default and child.value != ',': + if child.type == token.NUMBER: + if re.match(r'\d+[lL]?$', child.value): + inferred_type = 'int' + else: + inferred_type = 'float' # TODO: complex? + elif child.type == token.STRING: + if child.value.startswith(('u', 'U')): + inferred_type = 'unicode' + else: + inferred_type = 'str' + elif child.type == token.NAME and child.value in ('True', 'False'): + inferred_type = 'bool' + elif child.value == ',': + if inferred_type: + argtypes.append(stars + inferred_type) + # Reset + stars = inferred_type = '' + in_default = False + at_start = False + if inferred_type: + argtypes.append(stars + inferred_type) + return argtypes, restype diff --git a/pyannotate_tools/fixes/fix_annotate_command.py b/pyannotate_tools/fixes/fix_annotate_command.py new file mode 100644 index 0000000..ab907b5 --- /dev/null +++ b/pyannotate_tools/fixes/fix_annotate_command.py @@ -0,0 +1,34 @@ +from __future__ import absolute_import, print_function + +import json +import shlex +import subprocess + +from .fix_annotate_json import BaseFixAnnotateFromSignature, FixAnnotateJson as _FixAnnotateJson + +class FixAnnotateCommand(BaseFixAnnotateFromSignature): + # run after FixAnnotateJson + run_order = _FixAnnotateJson.run_order + 1 + + command = None + + @classmethod + def set_command(cls, command): + cls.command = command + + def get_command(self, filename, lineno): + return shlex.split(self.command.format(filename=filename, lineno=lineno)) + + def get_types(self, node, results, funcname): + cmd = self.get_command(self.filename, node.get_lineno()) + try: + out = subprocess.check_output(cmd, stderr=subprocess.STDOUT) + except subprocess.CalledProcessError as err: + self.log_message("Line %d: Failed calling `%s`: %s" % + (node.get_lineno(), self.command, + err.output.rstrip())) + return None + + data = json.loads(out) + signature = data[0]['signature'] + return signature['arg_types'], signature['return_type'] diff --git a/pyannotate_tools/fixes/fix_annotate_json.py b/pyannotate_tools/fixes/fix_annotate_json.py index a24cd68..afa91d5 100644 --- a/pyannotate_tools/fixes/fix_annotate_json.py +++ b/pyannotate_tools/fixes/fix_annotate_json.py @@ -34,7 +34,7 @@ # In Python 3.5.1 stdlib, typing.py does not define Text Text = str # type: ignore -from .fix_annotate import FixAnnotate +from .fix_annotate import BaseFixAnnotate # Taken from mypy codebase: # https://github.com/python/mypy/blob/745d300b8304c3dcf601477762bf9d70b9a4619c/mypy/main.py#L503 @@ -151,10 +151,10 @@ def count_args(node, results): previous_token_is_star = False return count, selfish, star, starstar -class FixAnnotateJson(FixAnnotate): + +class BaseFixAnnotateFromSignature(BaseFixAnnotate): needed_imports = None - line_drift = 5 def add_import(self, mod, name): if mod == self.current_module(): @@ -170,19 +170,29 @@ def patch_imports(self, types, node): self.needed_imports = None def set_filename(self, filename): - super(FixAnnotateJson, self).set_filename(filename) + super(BaseFixAnnotateFromSignature, self).set_filename(filename) self._current_module = crawl_up(filename)[1] def current_module(self): return self._current_module + def get_types(self, node, results, funcname): + raise NotImplementedError + def make_annotation(self, node, results): name = results['name'] assert isinstance(name, Leaf), repr(name) assert name.type == token.NAME, repr(name) funcname = get_funcname(node) - res = self.get_annotation_from_stub(node, results, funcname) + def make(node, results, funcname): + sig_data = self.get_types(node, results, funcname) + if sig_data: + arg_types, ret_type = sig_data + return self.process_types(node, results, arg_types, ret_type) + return None + + res = make(node, results, funcname) # If we couldn't find an annotation and this is a classmethod or # staticmethod, try again with just the funcname, since the # type collector can't figure out class names for those. @@ -191,13 +201,86 @@ def make_annotation(self, node, results): if not res: decs = self.get_decorators(node) if 'staticmethod' in decs or 'classmethod' in decs: - res = self.get_annotation_from_stub(node, results, name.value) - + res = make(node, results, name.value) return res + def process_types(self, node, results, arg_types, ret_type): + # Passes 1-2 don't always understand *args or **kwds, + # so add '*Any' or '**Any' at the end if needed. + count, selfish, star, starstar = count_args(node, results) + for arg_type in arg_types: + if arg_type.startswith('**'): + starstar = False + elif arg_type.startswith('*'): + star = False + if star: + arg_types.append('*Any') + if starstar: + arg_types.append('**Any') + # Pass 1 omits the first arg iff it's named 'self' or 'cls', + # even if it's not a method, so insert `Any` as needed + # (but only if it's not actually a method). + if selfish and len(arg_types) == count - 1: + if self.is_method(node): + count -= 1 # Leave out the type for 'self' or 'cls' + else: + arg_types.insert(0, 'Any') + # If after those adjustments the count is still off, + # print a warning and skip this node. + if len(arg_types) != count: + self.log_message("%s:%d: source has %d args, annotation has %d -- skipping" % + (self.filename, node.get_lineno(), count, len(arg_types))) + return None + + arg_types = [self.update_type_names(arg_type) for arg_type in arg_types] + # Avoid common error "No return value expected" + if ret_type == 'None' and self.has_return_exprs(node): + ret_type = 'Optional[Any]' + # Special case for generators. + if (self.is_generator(node) and + not (ret_type == 'Iterator' or ret_type.startswith('Iterator['))): + if ret_type.startswith('Optional['): + assert ret_type[-1] == ']' + ret_type = ret_type[9:-1] + ret_type = 'Iterator[%s]' % ret_type + ret_type = self.update_type_names(ret_type) + return arg_types, ret_type + + def update_type_names(self, type_str): + # Replace e.g. `List[pkg.mod.SomeClass]` with + # `List[SomeClass]` and remember to import it. + return re.sub(r'[\w.:]+', self.type_updater, type_str) + + def type_updater(self, match): + # Replace `pkg.mod.SomeClass` with `SomeClass` + # and remember to import it. + word = match.group() + if word == '...': + return word + if '.' not in word and ':' not in word: + # Assume it's either builtin or from `typing` + if word in typing_all: + self.add_import('typing', word) + return word + # If there is a :, treat that as the separator between the + # module and the class. Otherwise assume everything but the + # last element is the module. + if ':' in word: + mod, name = word.split(':') + to_import = name.split('.', 1)[0] + else: + mod, name = word.rsplit('.', 1) + to_import = name + self.add_import(mod, to_import) + return name + + +class FixAnnotateJson(BaseFixAnnotateFromSignature): + stub_json_file = os.getenv('TYPE_COLLECTION_JSON') # JSON data for the current file stub_json = None # type: List[Dict[str, Any]] + line_drift = 5 @classmethod @contextmanager @@ -219,8 +302,8 @@ def init_stub_json(self): data = json.load(f) self.__class__.init_stub_json_from_data(data, self.filename) - def get_annotation_from_stub(self, node, results, funcname): - if not self.__class__.stub_json: + def get_types(self, node, results, funcname): + if self.__class__.stub_json is None: self.init_stub_json() data = self.__class__.stub_json # We are using relative paths in the JSON. @@ -250,74 +333,5 @@ def get_annotation_from_stub(self, node, results, funcname): (self.filename, node.get_lineno(), it['func_name'], it['line'])) return None if 'signature' in it: - sig = it['signature'] - arg_types = sig['arg_types'] - # Passes 1-2 don't always understand *args or **kwds, - # so add '*Any' or '**Any' at the end if needed. - count, selfish, star, starstar = count_args(node, results) - for arg_type in arg_types: - if arg_type.startswith('**'): - starstar = False - elif arg_type.startswith('*'): - star = False - if star: - arg_types.append('*Any') - if starstar: - arg_types.append('**Any') - # Pass 1 omits the first arg iff it's named 'self' or 'cls', - # even if it's not a method, so insert `Any` as needed - # (but only if it's not actually a method). - if selfish and len(arg_types) == count - 1: - if self.is_method(node): - count -= 1 # Leave out the type for 'self' or 'cls' - else: - arg_types.insert(0, 'Any') - # If after those adjustments the count is still off, - # print a warning and skip this node. - if len(arg_types) != count: - self.log_message("%s:%d: source has %d args, annotation has %d -- skipping" % - (self.filename, node.get_lineno(), count, len(arg_types))) - return None - ret_type = sig['return_type'] - arg_types = [self.update_type_names(arg_type) for arg_type in arg_types] - # Avoid common error "No return value expected" - if ret_type == 'None' and self.has_return_exprs(node): - ret_type = 'Optional[Any]' - # Special case for generators. - if (self.is_generator(node) and - not (ret_type == 'Iterator' or ret_type.startswith('Iterator['))): - if ret_type.startswith('Optional['): - assert ret_type[-1] == ']' - ret_type = ret_type[9:-1] - ret_type = 'Iterator[%s]' % ret_type - ret_type = self.update_type_names(ret_type) - return arg_types, ret_type + return it['signature']['arg_types'], it['signature']['return_type'] return None - - def update_type_names(self, type_str): - # Replace e.g. `List[pkg.mod.SomeClass]` with - # `List[SomeClass]` and remember to import it. - return re.sub(r'[\w.:]+', self.type_updater, type_str) - - def type_updater(self, match): - # Replace `pkg.mod.SomeClass` with `SomeClass` - # and remember to import it. - word = match.group() - if word == '...': - return word - if '.' not in word and ':' not in word: - # Assume it's either builtin or from `typing` - if word in typing_all: - self.add_import('typing', word) - return word - # If there is a :, treat that as the separator between the - # module and the class. Otherwise assume everything but the - # last element is the module. - if ':' in word: - mod, name = word.split(':') - to_import = name.split('.', 1)[0] - else: - mod, name = word.rsplit('.', 1) - to_import = name - self.add_import(mod, to_import) - return name