trailing-whitespace: add option for custom chars to strip

This commit is contained in:
iconmaster5326
2019-10-25 11:12:49 -04:00
parent 0f7b5e0c4f
commit 886dfc4205

View File

@ -7,10 +7,15 @@ from typing import Optional
from typing import Sequence
def _fix_file(filename, is_markdown): # type: (str, bool) -> bool
def _fix_file(filename, is_markdown, chars_to_strip):
# type: (str, bool, Optional[bytes]) -> bool
with open(filename, mode='rb') as file_processed:
lines = file_processed.readlines()
newlines = [_process_line(line, is_markdown) for line in lines]
newlines = [
_process_line(line, is_markdown, chars_to_strip)
for line
in lines
]
if newlines != lines:
with open(filename, mode='wb') as file_processed:
for line in newlines:
@ -20,7 +25,8 @@ def _fix_file(filename, is_markdown): # type: (str, bool) -> bool
return False
def _process_line(line, is_markdown): # type: (bytes, bool) -> bytes
def _process_line(line, is_markdown, chars_to_strip):
# type: (bytes, bool, Optional[bytes]) -> bytes
if line[-2:] == b'\r\n':
eol = b'\r\n'
elif line[-1:] == b'\n':
@ -29,8 +35,8 @@ def _process_line(line, is_markdown): # type: (bytes, bool) -> bytes
eol = b''
# preserve trailing two-space for non-blank lines in markdown files
if is_markdown and (not line.isspace()) and line.endswith(b' ' + eol):
return line.rstrip() + b' ' + eol
return line.rstrip() + eol
return line.rstrip(chars_to_strip) + b' ' + eol
return line.rstrip(chars_to_strip) + eol
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
@ -50,6 +56,11 @@ def main(argv=None): # type: (Optional[Sequence[str]]) -> int
'default: %(default)s'
),
)
parser.add_argument(
'--chars',
help='The set of characters to strip from the end of lines. '
'Defaults to all whitespace characters.',
)
parser.add_argument('filenames', nargs='*', help='Filenames to fix')
args = parser.parse_args(argv)
@ -78,7 +89,11 @@ def main(argv=None): # type: (Optional[Sequence[str]]) -> int
for filename in args.filenames:
_, extension = os.path.splitext(filename.lower())
md = all_markdown or extension in md_exts
if _fix_file(filename, md):
if _fix_file(
filename,
md,
None if args.chars is None else bytes(args.chars, 'utf-8'),
):
print('Fixing {}'.format(filename))
return_code = 1
return return_code