diff --git a/util/style.py b/util/style.py index 68f1cdffe..cd9cd9652 100644 --- a/util/style.py +++ b/util/style.py @@ -42,6 +42,7 @@ sys.path.insert(0, current_dir) sys.path.insert(1, joinpath(dirname(current_dir), 'src', 'python')) from m5.util import neg_inf, pos_inf, Region, Regions +import sort_includes from file_types import lang_type all_regions = Region(neg_inf, pos_inf) @@ -184,6 +185,15 @@ class Verifier(object): f.write(line) f.close() + def apply(self, filename, prompt, regions=all_regions): + if not self.skip(filename): + errors = self.check(filename, regions) + if errors: + if prompt(filename, self.fix, regions): + return True + return False + + class Whitespace(Verifier): languages = set(('C', 'C++', 'swig', 'python', 'asm', 'isa', 'scons')) test_name = 'whitespace' @@ -214,6 +224,53 @@ class Whitespace(Verifier): return line.rstrip() + '\n' +class SortedIncludes(Verifier): + languages = sort_includes.default_languages + def __init__(self, *args, **kwargs): + super(SortedIncludes, self).__init__(*args, **kwargs) + self.sort_includes = sort_includes.SortIncludes() + + def check(self, filename, regions=all_regions): + f = self.open(filename, 'r') + + lines = [ l.rstrip('\n') for l in f.xreadlines() ] + old = ''.join(line + '\n' for line in lines) + f.close() + + language = lang_type(filename, lines[0]) + sort_lines = list(self.sort_includes(lines, filename, language)) + new = ''.join(line + '\n' for line in sort_lines) + + mod = modified_regions(old, new) + modified = mod & regions + print mod, regions, modified + + if modified: + self.write("invalid sorting of includes\n") + if self.ui.verbose: + for start, end in modified.regions: + self.write("bad region [%d, %d)\n" % (start, end)) + return 1 + + return 0 + + def fix(self, filename, regions=all_regions): + f = self.open(filename, 'r+') + + old = f.readlines() + lines = [ l.rstrip('\n') for l in old ] + language = lang_type(filename, lines[0]) + sort_lines = list(self.sort_includes(lines, filename, language)) + new = ''.join(line + '\n' for line in sort_lines) + + f.seek(0) + f.truncate() + + for i,line in enumerate(sort_lines): + f.write(line) + f.write('\n') + f.close() + def linelen(line): tabs = line.count('\t') if not tabs: @@ -343,15 +400,16 @@ def do_check_style(hgui, repo, *files, **args): modified, added, removed, deleted, unknown, ignore, clean = repo.status() whitespace = Whitespace(ui) + sorted_includes = SortedIncludes(ui) for fname in added: - if skip(fname) or whitespace.skip(fname): + if skip(fname): continue - errors = whitespace.check(fname) - if errors: - print errors - if prompt(fname, whitespace.fix): - return True + if whitespace.apply(fname, prompt): + return True + + if sorted_includes.apply(fname, prompt): + return True try: wctx = repo.workingctx() @@ -360,15 +418,18 @@ def do_check_style(hgui, repo, *files, **args): wctx = context.workingctx(repo) for fname in modified: - if skip(fname) or whitespace.skip(fname): + if skip(fname): continue regions = modregions(wctx, fname) - errors = whitespace.check(fname, regions) - if errors: - if prompt(fname, whitespace.fix, regions): - return True + if whitespace.apply(fname, prompt, regions): + return True + + if sorted_includes.apply(fname, prompt, regions): + return True + + return False def do_check_format(hgui, repo, **args): ui = MercurialUI(hgui, hgui.verbose, auto)