diff --git a/tools/export_to_ipynb.py b/tools/export_to_ipynb.py
index 3244cf2a64..88a18f064b 100755
--- a/tools/export_to_ipynb.py
+++ b/tools/export_to_ipynb.py
@@ -3,13 +3,14 @@
import ast
import os
import sys
+import re
from nbformat import v3
from nbformat import v4
input_file = sys.argv[1]
print(f'reading {input_file}')
with open(input_file) as fpin:
- text = fpin.read()
+ text = fpin.read()
# Compute output file path.
output_file = input_file
@@ -28,7 +29,7 @@ nbook = v3.reads_py('')
nbook = v4.upgrade(nbook) # Upgrade v3 to v4
print('Adding copyright cell...')
-google = '##### Copyright 2021 Google LLC.'
+google = '##### Copyright 2022 Google LLC.'
nbook['cells'].append(v4.new_markdown_cell(source=google, id='google'))
print('Adding license cell...')
@@ -66,7 +67,7 @@ link = f'''
'''
nbook['cells'].append(v4.new_markdown_cell(source=link, id='link'))
-print('Installing ortools cell...')
+print('Adding ortools install cell...')
install_doc = ('First, you must install '
'[ortools](https://pypi.org/project/ortools/) package in this '
'colab.')
@@ -82,27 +83,95 @@ line_start[0] = 0
lines = text.split('\n')
full_text = ''
-for c_block, s, e in zip(all_blocks, line_start, line_start[1:] + [len(lines)]):
- print(c_block)
- c_text = '\n'.join(lines[s:e])
- if isinstance(c_block,
- ast.If) and c_block.test.comparators[0].s == '__main__':
- print('Skip if main', lines[s:e])
- elif isinstance(c_block, ast.FunctionDef) and c_block.name == 'main':
- # remove start and de-indent lines
- c_lines = lines[s + 1:e]
- spaces_to_delete = c_block.body[0].col_offset
- fixed_lines = [
- n_line[spaces_to_delete:]
- if n_line.startswith(' ' * spaces_to_delete) else n_line
- for n_line in c_lines
- ]
- fixed_text = '\n'.join(fixed_lines)
- print('Unwrapping main function')
- full_text += fixed_text
- else:
- print('appending', c_block)
- full_text += c_text + '\n'
+for idx, (c_block, s, e) in enumerate(
+ zip(all_blocks, line_start, line_start[1:] + [len(lines)])):
+ print(f'block[{idx}]: {c_block}')
+ c_text = '\n'.join(lines[s:e])
+ # Clean boilerplate header and description
+ if (idx == 0 and isinstance(c_block, ast.Expr) and
+ isinstance(c_block.value, ast.Constant)):
+ print('Adding description cell...')
+ filtered_lines = lines[s:e]
+ #filtered_lines = list(filter(lambda l: not l.startswith('#!'), lines[s:e]))
+ filtered_lines = list(filter(lambda l: not re.search(r'^#!', l), filtered_lines))
+ filtered_lines = list(filter(lambda l: not re.search(r'# \[START .*\]$', l), filtered_lines))
+ filtered_lines = list(filter(lambda l: not re.search(r'# \[END .*\]$', l), filtered_lines))
+ # TODO(mizux): Remove only the copyright header, not every line starting with '#'.
+ filtered_lines = list(filter(lambda l: not l.startswith(r'#'), filtered_lines))
+ filtered_lines = [s.replace(r'"""', '') for s in filtered_lines]
+ filtered_text = '\n'.join(filtered_lines)
+ nbook['cells'].append(v4.new_markdown_cell(source=filtered_text, id='description'))
+ # Remove absl app and flags import
+ elif (isinstance(c_block, ast.ImportFrom) and c_block.module == 'absl'
+ and c_block.names[0].name in ('flags', 'app')):
+ print(f'Removing import {c_block.module}.{c_block.names[0].name}...')
+ # Rewrite `FLAGS = flags.FLAGS`
+ elif (isinstance(c_block, ast.Assign) and
+ isinstance(c_block.targets[0], ast.Name) and
+ c_block.targets[0].id == 'FLAGS'):
+ print('Adding FLAGS class...')
+ fixed_lines = ['class FLAGS: pass\n']
+ full_text += '\n'.join(fixed_lines) + '\n'
+ # Rewrite `flags.DEFINE_*(*)`
+ elif (isinstance(c_block, ast.Expr) and
+ isinstance(c_block.value, ast.Call) and
+ isinstance(c_block.value.func, ast.Attribute) and
+ c_block.value.func.value.id == 'flags'):
+ print('Adding FLAGS field...')
+ fixed_lines = []
+ attr = c_block.value.func.attr
+ if attr in ('DEFINE_integer', 'DEFINE_bool', 'DEFINE_string'):
+ args = c_block.value.args
+ #print(f'args: {args}')
+ name = args[0].value
+ if isinstance(args[1], ast.Constant):
+ value = args[1].value
+ elif isinstance(args[1], ast.UnaryOp):
+ if isinstance(args[1].op, ast.USub):
+ value = -1 * int(args[1].operand.value)
+ else:
+ print(f'unknow value operator: "{args[1].op}"')
+ sys.exit(2)
+ else:
+ print(f'unknow value: "{args[1]}"')
+ sys.exit(2)
+ comment = args[2].value
+
+ print(f'FLAGS.{name} = \'{value}\' # {comment}')
+ if attr in ('DEFINE_integer', 'DEFINE_bool'):
+ fixed_lines.append(f'FLAGS.{name} = {value} # {comment}\n')
+ else:
+ fixed_lines.append(f'FLAGS.{name} = \'{value}\' # {comment}\n')
+ else:
+ print(f'unknow method: "{attr}"')
+ sys.exit(2)
+ full_text += '\n'.join(fixed_lines)
+ # Add empty line after the last flags.DEFINE
+ if e-2 >= s and lines[e-1] == '' and lines[e-2] == '':
+ full_text += '\n'
+ # Unwrap the `if __name__ == '__main__':` block (matched on its '__main__' comparator)
+ elif (isinstance(c_block, ast.If) and c_block.test.comparators[0].s == '__main__'):
+ print('Unwrapping main function...')
+ c_lines = lines[s + 1:e]
+ # Drop the `if` line itself (already skipped via s + 1) and de-indent the body lines
+ spaces_to_delete = c_block.body[0].col_offset
+ fixed_lines = [
+ n_line[spaces_to_delete:]
+ if n_line.startswith(' ' * spaces_to_delete) else n_line
+ for n_line in c_lines
+ ]
+ filtered_lines = fixed_lines
+ filtered_lines = list(filter(lambda l: not re.search(r'# \[START .*\]$', l), filtered_lines))
+ filtered_lines = list(filter(lambda l: not re.search(r'# \[END .*\]$', l), filtered_lines))
+ filtered_lines = [re.sub(r'app.run\((.*)\)$', r'\1()', s) for s in filtered_lines]
+ full_text += '\n'.join(filtered_lines) + '\n'
+ # Others
+ else:
+ print('Appending block...')
+ filtered_lines = lines[s:e]
+ filtered_lines = list(filter(lambda l: not re.search(r'# \[START .*\]$', l), filtered_lines))
+ filtered_lines = list(filter(lambda l: not re.search(r'# \[END .*\]$', l), filtered_lines))
+ full_text += '\n'.join(filtered_lines) + '\n'
nbook['cells'].append(v4.new_code_cell(source=full_text, id='code'))
@@ -110,4 +179,4 @@ jsonform = v4.writes(nbook) + '\n'
print(f'writing {output_file}')
with open(output_file, 'w') as fpout:
- fpout.write(jsonform)
+ fpout.write(jsonform)