examples: Fix notebook generation (Fix #3244)

This commit is contained in:
Corentin Le Molgat
2022-04-13 14:32:40 +02:00
parent 0f99350f64
commit 1bee57b277

View File

@@ -3,13 +3,14 @@
import ast
import os
import sys
import re
from nbformat import v3
from nbformat import v4
input_file = sys.argv[1]
print(f'reading {input_file}')
with open(input_file) as fpin:
text = fpin.read()
text = fpin.read()
# Compute output file path.
output_file = input_file
@@ -28,7 +29,7 @@ nbook = v3.reads_py('')
nbook = v4.upgrade(nbook) # Upgrade v3 to v4
print('Adding copyright cell...')
google = '##### Copyright 2021 Google LLC.'
google = '##### Copyright 2022 Google LLC.'
nbook['cells'].append(v4.new_markdown_cell(source=google, id='google'))
print('Adding license cell...')
@@ -66,7 +67,7 @@ link = f'''<table align=\"left\">
</table>'''
nbook['cells'].append(v4.new_markdown_cell(source=link, id='link'))
print('Installing ortools cell...')
print('Adding ortools install cell...')
install_doc = ('First, you must install '
'[ortools](https://pypi.org/project/ortools/) package in this '
'colab.')
@@ -82,27 +83,95 @@ line_start[0] = 0
lines = text.split('\n')
full_text = ''
for c_block, s, e in zip(all_blocks, line_start, line_start[1:] + [len(lines)]):
print(c_block)
c_text = '\n'.join(lines[s:e])
if isinstance(c_block,
ast.If) and c_block.test.comparators[0].s == '__main__':
print('Skip if main', lines[s:e])
elif isinstance(c_block, ast.FunctionDef) and c_block.name == 'main':
# remove start and de-indent lines
c_lines = lines[s + 1:e]
spaces_to_delete = c_block.body[0].col_offset
fixed_lines = [
n_line[spaces_to_delete:]
if n_line.startswith(' ' * spaces_to_delete) else n_line
for n_line in c_lines
]
fixed_text = '\n'.join(fixed_lines)
print('Unwrapping main function')
full_text += fixed_text
else:
print('appending', c_block)
full_text += c_text + '\n'
for idx, (c_block, s, e) in enumerate(
zip(all_blocks, line_start, line_start[1:] + [len(lines)])):
print(f'block[{idx}]: {c_block}')
c_text = '\n'.join(lines[s:e])
# Clean boilerplate header and description
if (idx == 0 and isinstance(c_block, ast.Expr) and
isinstance(c_block.value, ast.Constant)):
print('Adding description cell...')
filtered_lines = lines[s:e]
#filtered_lines = list(filter(lambda l: not l.startswith('#!'), lines[s:e]))
filtered_lines = list(filter(lambda l: not re.search(r'^#!', l), filtered_lines))
filtered_lines = list(filter(lambda l: not re.search(r'# \[START .*\]$', l), filtered_lines))
filtered_lines = list(filter(lambda l: not re.search(r'# \[END .*\]$', l), filtered_lines))
# TODO(mizux): Remove only copyright not all line with '^#'
filtered_lines = list(filter(lambda l: not l.startswith(r'#'), filtered_lines))
filtered_lines = [s.replace(r'"""', '') for s in filtered_lines]
filtered_text = '\n'.join(filtered_lines)
nbook['cells'].append(v4.new_markdown_cell(source=filtered_text, id='description'))
# Remove absl app and flags import
elif (isinstance(c_block, ast.ImportFrom) and c_block.module == 'absl'
and c_block.names[0].name in ('flags', 'app')):
print(f'Removing import {c_block.module}.{c_block.names[0].name}...')
# Rewrite `FLAGS = flags.FLAGS`
elif (isinstance(c_block, ast.Assign) and
isinstance(c_block.targets[0], ast.Name) and
c_block.targets[0].id == 'FLAGS'):
print('Adding FLAGS class...')
fixed_lines = ['class FLAGS: pass\n']
full_text += '\n'.join(fixed_lines) + '\n'
# Rewrite `flags.DEFINE_*(*)`
elif (isinstance(c_block, ast.Expr) and
isinstance(c_block.value, ast.Call) and
isinstance(c_block.value.func, ast.Attribute) and
c_block.value.func.value.id == 'flags'):
print('Adding FLAGS field...')
fixed_lines = []
attr = c_block.value.func.attr
if attr in ('DEFINE_integer', 'DEFINE_bool', 'DEFINE_string'):
args = c_block.value.args
#print(f'args: {args}')
name = args[0].value
if isinstance(args[1], ast.Constant):
value = args[1].value
elif isinstance(args[1], ast.UnaryOp):
if isinstance(args[1].op, ast.USub):
value = -1 * int(args[1].operand.value)
else:
print(f'unknow value operator: "{args[1].op}"')
sys.exit(2)
else:
print(f'unknow value: "{args[1]}"')
sys.exit(2)
comment = args[2].value
print(f'FLAGS.{name} = \'{value}\' # {comment}')
if attr in ('DEFINE_integer', 'DEFINE_bool'):
fixed_lines.append(f'FLAGS.{name} = {value} # {comment}\n')
else:
fixed_lines.append(f'FLAGS.{name} = \'{value}\' # {comment}\n')
else:
print(f'unknow method: "{attr}"')
sys.exit(2)
full_text += '\n'.join(fixed_lines)
# Add empty line after the last flags.DEFINE
if e-2 >= s and lines[e-1] == '' and lines[e-2] == '':
full_text += '\n'
# Unwrap __main__ function
elif (isinstance(c_block, ast.If) and c_block.test.comparators[0].s == '__main__'):
print('Unwrapping main function...')
c_lines = lines[s + 1:e]
# remove start and de-indent lines
spaces_to_delete = c_block.body[0].col_offset
fixed_lines = [
n_line[spaces_to_delete:]
if n_line.startswith(' ' * spaces_to_delete) else n_line
for n_line in c_lines
]
filtered_lines = fixed_lines
filtered_lines = list(filter(lambda l: not re.search(r'# \[START .*\]$', l), filtered_lines))
filtered_lines = list(filter(lambda l: not re.search(r'# \[END .*\]$', l), filtered_lines))
filtered_lines = [re.sub(r'app.run\((.*)\)$', r'\1()', s) for s in filtered_lines]
full_text += '\n'.join(filtered_lines) + '\n'
# Others
else:
print('Appending block...')
filtered_lines = lines[s:e]
filtered_lines = list(filter(lambda l: not re.search(r'# \[START .*\]$', l), filtered_lines))
filtered_lines = list(filter(lambda l: not re.search(r'# \[END .*\]$', l), filtered_lines))
full_text += '\n'.join(filtered_lines) + '\n'
nbook['cells'].append(v4.new_code_cell(source=full_text, id='code'))
@@ -110,4 +179,4 @@ jsonform = v4.writes(nbook) + '\n'
print(f'writing {output_file}')
with open(output_file, 'w') as fpout:
fpout.write(jsonform)
fpout.write(jsonform)