examples: Fix notebook generation (Fix #3244)
This commit is contained in:
@@ -3,13 +3,14 @@
|
||||
import ast
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
from nbformat import v3
|
||||
from nbformat import v4
|
||||
|
||||
input_file = sys.argv[1]
|
||||
print(f'reading {input_file}')
|
||||
with open(input_file) as fpin:
|
||||
text = fpin.read()
|
||||
text = fpin.read()
|
||||
|
||||
# Compute output file path.
|
||||
output_file = input_file
|
||||
@@ -28,7 +29,7 @@ nbook = v3.reads_py('')
|
||||
nbook = v4.upgrade(nbook) # Upgrade v3 to v4
|
||||
|
||||
print('Adding copyright cell...')
|
||||
google = '##### Copyright 2021 Google LLC.'
|
||||
google = '##### Copyright 2022 Google LLC.'
|
||||
nbook['cells'].append(v4.new_markdown_cell(source=google, id='google'))
|
||||
|
||||
print('Adding license cell...')
|
||||
@@ -66,7 +67,7 @@ link = f'''<table align=\"left\">
|
||||
</table>'''
|
||||
nbook['cells'].append(v4.new_markdown_cell(source=link, id='link'))
|
||||
|
||||
print('Installing ortools cell...')
|
||||
print('Adding ortools install cell...')
|
||||
install_doc = ('First, you must install '
|
||||
'[ortools](https://pypi.org/project/ortools/) package in this '
|
||||
'colab.')
|
||||
@@ -82,27 +83,95 @@ line_start[0] = 0
|
||||
lines = text.split('\n')
|
||||
|
||||
full_text = ''
|
||||
for c_block, s, e in zip(all_blocks, line_start, line_start[1:] + [len(lines)]):
|
||||
print(c_block)
|
||||
c_text = '\n'.join(lines[s:e])
|
||||
if isinstance(c_block,
|
||||
ast.If) and c_block.test.comparators[0].s == '__main__':
|
||||
print('Skip if main', lines[s:e])
|
||||
elif isinstance(c_block, ast.FunctionDef) and c_block.name == 'main':
|
||||
# remove start and de-indent lines
|
||||
c_lines = lines[s + 1:e]
|
||||
spaces_to_delete = c_block.body[0].col_offset
|
||||
fixed_lines = [
|
||||
n_line[spaces_to_delete:]
|
||||
if n_line.startswith(' ' * spaces_to_delete) else n_line
|
||||
for n_line in c_lines
|
||||
]
|
||||
fixed_text = '\n'.join(fixed_lines)
|
||||
print('Unwrapping main function')
|
||||
full_text += fixed_text
|
||||
else:
|
||||
print('appending', c_block)
|
||||
full_text += c_text + '\n'
|
||||
for idx, (c_block, s, e) in enumerate(
|
||||
zip(all_blocks, line_start, line_start[1:] + [len(lines)])):
|
||||
print(f'block[{idx}]: {c_block}')
|
||||
c_text = '\n'.join(lines[s:e])
|
||||
# Clean boilerplate header and description
|
||||
if (idx == 0 and isinstance(c_block, ast.Expr) and
|
||||
isinstance(c_block.value, ast.Constant)):
|
||||
print('Adding description cell...')
|
||||
filtered_lines = lines[s:e]
|
||||
#filtered_lines = list(filter(lambda l: not l.startswith('#!'), lines[s:e]))
|
||||
filtered_lines = list(filter(lambda l: not re.search(r'^#!', l), filtered_lines))
|
||||
filtered_lines = list(filter(lambda l: not re.search(r'# \[START .*\]$', l), filtered_lines))
|
||||
filtered_lines = list(filter(lambda l: not re.search(r'# \[END .*\]$', l), filtered_lines))
|
||||
# TODO(mizux): Remove only copyright not all line with '^#'
|
||||
filtered_lines = list(filter(lambda l: not l.startswith(r'#'), filtered_lines))
|
||||
filtered_lines = [s.replace(r'"""', '') for s in filtered_lines]
|
||||
filtered_text = '\n'.join(filtered_lines)
|
||||
nbook['cells'].append(v4.new_markdown_cell(source=filtered_text, id='description'))
|
||||
# Remove absl app and flags import
|
||||
elif (isinstance(c_block, ast.ImportFrom) and c_block.module == 'absl'
|
||||
and c_block.names[0].name in ('flags', 'app')):
|
||||
print(f'Removing import {c_block.module}.{c_block.names[0].name}...')
|
||||
# Rewrite `FLAGS = flags.FLAGS`
|
||||
elif (isinstance(c_block, ast.Assign) and
|
||||
isinstance(c_block.targets[0], ast.Name) and
|
||||
c_block.targets[0].id == 'FLAGS'):
|
||||
print('Adding FLAGS class...')
|
||||
fixed_lines = ['class FLAGS: pass\n']
|
||||
full_text += '\n'.join(fixed_lines) + '\n'
|
||||
# Rewrite `flags.DEFINE_*(*)`
|
||||
elif (isinstance(c_block, ast.Expr) and
|
||||
isinstance(c_block.value, ast.Call) and
|
||||
isinstance(c_block.value.func, ast.Attribute) and
|
||||
c_block.value.func.value.id == 'flags'):
|
||||
print('Adding FLAGS field...')
|
||||
fixed_lines = []
|
||||
attr = c_block.value.func.attr
|
||||
if attr in ('DEFINE_integer', 'DEFINE_bool', 'DEFINE_string'):
|
||||
args = c_block.value.args
|
||||
#print(f'args: {args}')
|
||||
name = args[0].value
|
||||
if isinstance(args[1], ast.Constant):
|
||||
value = args[1].value
|
||||
elif isinstance(args[1], ast.UnaryOp):
|
||||
if isinstance(args[1].op, ast.USub):
|
||||
value = -1 * int(args[1].operand.value)
|
||||
else:
|
||||
print(f'unknow value operator: "{args[1].op}"')
|
||||
sys.exit(2)
|
||||
else:
|
||||
print(f'unknow value: "{args[1]}"')
|
||||
sys.exit(2)
|
||||
comment = args[2].value
|
||||
|
||||
print(f'FLAGS.{name} = \'{value}\' # {comment}')
|
||||
if attr in ('DEFINE_integer', 'DEFINE_bool'):
|
||||
fixed_lines.append(f'FLAGS.{name} = {value} # {comment}\n')
|
||||
else:
|
||||
fixed_lines.append(f'FLAGS.{name} = \'{value}\' # {comment}\n')
|
||||
else:
|
||||
print(f'unknow method: "{attr}"')
|
||||
sys.exit(2)
|
||||
full_text += '\n'.join(fixed_lines)
|
||||
# Add empty line after the last flags.DEFINE
|
||||
if e-2 >= s and lines[e-1] == '' and lines[e-2] == '':
|
||||
full_text += '\n'
|
||||
# Unwrap __main__ function
|
||||
elif (isinstance(c_block, ast.If) and c_block.test.comparators[0].s == '__main__'):
|
||||
print('Unwrapping main function...')
|
||||
c_lines = lines[s + 1:e]
|
||||
# remove start and de-indent lines
|
||||
spaces_to_delete = c_block.body[0].col_offset
|
||||
fixed_lines = [
|
||||
n_line[spaces_to_delete:]
|
||||
if n_line.startswith(' ' * spaces_to_delete) else n_line
|
||||
for n_line in c_lines
|
||||
]
|
||||
filtered_lines = fixed_lines
|
||||
filtered_lines = list(filter(lambda l: not re.search(r'# \[START .*\]$', l), filtered_lines))
|
||||
filtered_lines = list(filter(lambda l: not re.search(r'# \[END .*\]$', l), filtered_lines))
|
||||
filtered_lines = [re.sub(r'app.run\((.*)\)$', r'\1()', s) for s in filtered_lines]
|
||||
full_text += '\n'.join(filtered_lines) + '\n'
|
||||
# Others
|
||||
else:
|
||||
print('Appending block...')
|
||||
filtered_lines = lines[s:e]
|
||||
filtered_lines = list(filter(lambda l: not re.search(r'# \[START .*\]$', l), filtered_lines))
|
||||
filtered_lines = list(filter(lambda l: not re.search(r'# \[END .*\]$', l), filtered_lines))
|
||||
full_text += '\n'.join(filtered_lines) + '\n'
|
||||
|
||||
nbook['cells'].append(v4.new_code_cell(source=full_text, id='code'))
|
||||
|
||||
@@ -110,4 +179,4 @@ jsonform = v4.writes(nbook) + '\n'
|
||||
|
||||
print(f'writing {output_file}')
|
||||
with open(output_file, 'w') as fpout:
|
||||
fpout.write(jsonform)
|
||||
fpout.write(jsonform)
|
||||
|
||||
Reference in New Issue
Block a user