在阅读算法文献或者数学相关的文章中经常会看到一些简单或复杂的数学公式,最近在分享此类文章时,想使用LaTex键入数学公式以美化阅读,发现需要反复去查询LaTex相关的语法,效率较低且容易出错。

最近 GitHub 上出现了一个开源项目 latexify_py,它使用 Python 就能生成 LaTeX 数学公式。打开Google Colaboratory示例列举了几个案例:

image-20200910103526562

先试试看

在本地安装相应的Python包,Python版本 >= 3.6

pip install latexify-py

参考官方示例进行测试:

import math
import latexify

@latexify.with_latex
def solve(a, b, c):
    return (-b + math.sqrt(b ** 2 - 4 * a * c)) / (2 * a)

if __name__ == '__main__':
    print(solve)

终端打印结果为:

\mathrm{solve}(a, b, c)\triangleq \frac{-b + \sqrt{b^{2} - 4ac}}{2a}

将打印结果输入到支持LaTeX的编辑器中,以Typora为例。选择插入公式块:

\mathrm{solve}(a, b, c)\triangleq \frac{-b + \sqrt{b^{2} - 4ac}}{2a}

于是,把最近阅读的facebook开源的prophet时间序列预测算法提到的饱和增长模型公式进行测试,原文中为

image-20200910115235060

开始在python中键入代码:

@latexify.with_latex
def g(t):
    return C(t) / (1 + exp(1-(k + alpha(t) ** T * delta) * (t -(m + alpha(t) ** T * gamma))))

终端打印结果并输入Typora为:

\mathrm{g}(t)\triangleq \frac{\mathrm{C}\left(t\right)}{1 + \mathrm{exp}\left(1 - (k + \mathrm{{\alpha}}\left(t\right)^{t}{\delta})(t - m + \mathrm{{\alpha}}\left(t\right)^{T}{\gamma})\right)}
\mathrm{g}(t)\triangleq \frac{\mathrm{C}\left(t\right)}{1 + \mathrm{exp}\left(1 - (k + \mathrm{{\alpha}}\left(t\right)^{T}{\delta})(t - m + \mathrm{{\alpha}}\left(t\right)^{T}{\gamma})\right)}

对比发现python输出的公式中有一个错误:删除了一个括号,而python代码中是包含的,由

t - (m + \mathrm{{\alpha}}\left(t\right)^{T}{\gamma})

变成了:

t - m + \mathrm{{\alpha}}\left(t\right)^{T}{\gamma}

为了进一步验证上面出现的问题,输入一段很简单的代码:

@latexify.with_latex
def test(a, b):
    return  - (a + b)

输出的公式和预想的一致:

\mathrm{test}(a, b)\triangleq -\left(a + b\right)

这时,小小的修改一下代码:

@latexify.with_latex
def test(a, b):
    return  1 - (a + b)

预想的公式应该为:

\mathrm{test}(a, b)\triangleq 1 - (a + b)

而实际却是:

\mathrm{test}(a, b)\triangleq 1 - a + b

猜想,这可能是一个bug或者是输入的方式不对,虽然这个问题很好解决,但是一直很疑惑。。。。。

latexify_py做了什么?

为了一探究竟,尝试去阅读其源码,看看它都做了哪些事情?

首先入口是@latexify.with_latex这个注解。latexify提供with_latex和get_latex两个注解,with_latex只是先做一些初始化,实际也是调用get_latex。重点看一下get_latex,其源码:

def get_latex(fn, math_symbol=True):
  try:
    source = inspect.getsource(fn)##获取整个模块的源代码
  except Exception:
    # Maybe running on console.
    source = dill.source.getsource(fn)

  return LatexifyVisitor(math_symbol=math_symbol).visit(ast.parse(source)) ##ast.parse把源码解析为AST节点,AST是抽象语法树,不依赖于具体的文法,不依赖于语言的细节,我们将源代码转化为AST后,可以对AST做很多的操作

LatexifyVisitor继承ast的NodeVisitor,ast.NodeVisitor是一个专门用来遍历语法树的工具,可以通过继承这个类来完成对语法树的遍历以及遍历过程中的处理。

LatexifyVisitor首先从根节点root进行遍历,在遍历的过程中,每个节点类型都有专用的类型处理函数,以"visit_" + "Node类型"为名称,如果不存在,则调用通用的的处理函数generic_visit。

在latexify的core.py直接引入astunparse,将生成的ast打印出来:

def get_latex(fn, math_symbol=True):
  try:
    source = inspect.getsource(fn)

    print(astunparse.dump(ast.parse(source)))

  except Exception:
    # Maybe running on console.
    source = dill.source.getsource(fn)

  return LatexifyVisitor(math_symbol=math_symbol).visit(ast.parse(source))

下面是test对应的ast结构:

Module(
  body=[FunctionDef(
    name='test',
    args=arguments(
      posonlyargs=[],
      args=[
        arg(
          arg='a',
          annotation=None,
          type_comment=None),
        arg(
          arg='b',
          annotation=None,
          type_comment=None)],
      vararg=None,
      kwonlyargs=[],
      kw_defaults=[],
      kwarg=None,
      defaults=[]),
    body=[Return(value=BinOp(
      left=Constant(
        value=1,
        kind=None),
      op=Sub(),
      right=BinOp(
        left=Name(
          id='a',
          ctx=Load()),
        op=Add(),
        right=Name(
          id='b',
          ctx=Load()))))],
    decorator_list=[Attribute(
      value=Name(
        id='latexify',
        ctx=Load()),
      attr='with_latex',
      ctx=Load())],
    returns=None,
    type_comment=None)],
  type_ignores=[])

首先访问根节点root,root为Moudle类型,会调用visit_Moudle函数,以此始遍历子节点FunctionDef、Return和BinOp,调用对应的visit_FunctionDef、visit_Return和vist_BinOp。

参照打印出来的python公式代码和ast结构,来分析一下整体逻辑:

vist_FunctionDef

def visit_FunctionDef(self, node):
  name_str = r'\mathrm{' + str(node.name) + '}'
  arg_strs = [self._parse_math_symbols(str(arg.arg)) for arg in node.args.args]
  body_str = self.visit(node.body[0])
  return name_str + '(' + ', '.join(arg_strs) + r')\triangleq ' + body_str

遍历FunctionDef节点后,输出为:

\mathrm{test}(a,b)\triangleq

visit_Return

def visit_Return(self, node):
  return self.visit(node.value)

Return节点的值为子节点,类型为BinOp。ast将输入的代码分为left和right,test例子中,left为常数1,right是下一个子节点,类型为BinOp,op为运算符,这里为Sub减法。看看visit_BinOp:

visit_BinOp

def visit_BinOp(self, node):
  priority = {
      ast.Add: 10,
      ast.Sub: 10,
      ast.Mult: 20,
      ast.MatMult: 20,
      ast.Div: 20,
      ast.FloorDiv: 20,
      ast.Mod: 20,
      ast.Pow: 30,
  }

  def _unwrap(child):
    return self.visit(child)

  def _wrap(child):
    latex = _unwrap(child)
    if isinstance(child, ast.BinOp):
      cp = priority[type(child.op)] if type(child.op) in priority else 100
      pp = priority[type(node.op)] if type(node.op) in priority else 100

      if cp < pp:
        return '(' + latex + ')'
    return latex

  l = node.left
  r = node.right
  reprs = {
      ast.Add: (lambda: _wrap(l) + ' + ' + _wrap(r)),
      ast.Sub: (lambda: _wrap(l) + ' - ' + _wrap(r)),
      ast.Mult: (lambda: _wrap(l) + _wrap(r)),
      ast.MatMult: (lambda: _wrap(l) + _wrap(r)),
      ast.Div: (lambda: r'\frac{' + _unwrap(l) + '}{' + _unwrap(r) + '}'),
      ast.FloorDiv: (lambda: r'\left\lfloor\frac{' + _unwrap(l) + '}{' + _unwrap(r) + r'}\right\rfloor'),
      ast.Mod: (lambda: _wrap(l) + r' \bmod ' + _wrap(r)),
      ast.Pow: (lambda: _wrap(l) + '^{' + _unwrap(r) + '}'),
  }

  if type(node.op) in reprs:
    return reprs[type(node.op)]()
  else:
    return r'\mathrm{unknown\_binop}(' + _unwrap(l) + ', ' + _unwrap(r) + ')'

ast.Add和ast.Sub设置的优先级都为10,_wrap方法通过优先级来判断是否添加括号,即:

  cp = priority[type(child.op)] if type(child.op) in priority else 100
  pp = priority[type(node.op)] if type(node.op) in priority else 100
  if cp < pp:
    return '(' + latex + ')'

test例子中child.op为Sub,node.op是right中的op为Add,优先级相同不添加括号,所以输出:

1 - a + b

遍历结束后输出:

\mathrm{test}(a, b)\triangleq 1 - a + b

这和公式实际上表达的意思南辕北辙,解决方法就是将小于改为小于等于,即

 if cp <= pp:
    return '(' + latex + ')'