博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
【TensorFlow重大升级】自动将Python代码转为TF Graph,大幅简化动态图处理!
阅读量:6180 次
发布时间:2019-06-21

本文共 3934 字,大约阅读时间需要 13 分钟。

【新智元导读】TensorFlow发布重大功能改进AutoGraph,能自动将Python代码转换为TensorFlow Graph,TF动态图处理速度大幅提升!

今天,TensorFlow团队发布新功能“AutoGraph”,能自动将Python代码(包括控制流,print () 和其他Python原生特征)转换为纯TensorFlow图代码(pure TensorFlow graph code)。

不使用 Eager Execution编写TensorFlow代码需要进行一些元编程(metaprogramming) ——先编写一个创建图(Graph)的程序,稍后再执行这个Graph。这可能令人困惑,尤其是对开发者新手来说。一些特别棘手的情况涉及更复杂的模型,比如要使用 if 和 while 的模型,或者有 print () 等副作用或接受结构化输入的模型。

为什么我们需要Graph呢?Graph允许各种优化,例如删除常见的子表达式和融合内核(fusing kernel)。再者,Graph简化了分布式训练和部署到各种环境的过程,因为它们形成了独立于平台的模型计算过程。这对于模型在多个GPU或TPU上的分布式训练尤为重要,如果你通过TensorFlow Lite、移动端、物联网等其他平台分发模型,Graph也很重要。

下面是一个很简单的、你可能希望添加到Graph里的操作:

def huber_loss(a):  if tf.abs(a) <= delta:    loss = a * a / 2  else:    loss = delta * (tf.abs(a) - delta / 2)  return loss

通过Eager Execution,只是能做到这一点,但是由于Python解释器开销(interpreter overheads)或错过的程序优化机会,此类操作可能会很慢。

为了准备执行Graph,你需要重写这个以使用像 tf.cond () 这样的结构,但那样实现起来可能会耗时耗力而且很困难。AutoGraph可以为自动执行此类转换,将动态图编程的简易性保持很低的同时,获得基于Graph执行的性能优势。

在示例中,我们可以使用 autograph.convert () 来修饰函数,AutoGraph将自动生成 graph-ready 的代码。

使用AutoGraph,这段代码:

@autograph.convert()def huber_loss(a):  if tf.abs(a) <= delta:    loss = a * a / 2  else:    loss = delta * (tf.abs(a) - delta / 2)  return loss

在执行时将变成这种样子:

def tf__huber_loss(a):  with tf.name_scope('huber_loss'):    def if_true():      with tf.name_scope('if_true'):        loss = a * a / 2        return loss,    def if_false():      with tf.name_scope('if_false'):        loss = delta * (tf.abs(a) - delta / 2)        return loss,    loss = ag__.utils.run_cond(tf.less_equal(tf.abs(a), delta), if_true,        if_false)    return loss

你可以直接调用代码,就像TensorFlow op一样:

with tf.Graph().as_default():    x_tensor = tf.constant(9.0)  # The converted function works like a regular op: tensors in, tensors out.  huber_loss_tensor = huber_loss(x_tensor)  with tf.Session() as sess:    print('TensorFlow result: %2.2f\n' % sess.run(huber_loss_tensor))

综上,AutoGraph填补了Eager Execution和Graph之间的空白。AutoGraph 将你的 eager-style Python 代码自动转换为动态图生成(graph-generating)代码。

AutoGraph不仅仅是一组有用的宏指令(macro); 它涵盖Python语言的任何部分(利用源代码转换),包括控制流、函数应用程序和赋值、生成模板代码以及重构常用的Python让它易于转换为图形。

对于任何编译器,都会担心报错信息的可读性; 为此,AutoGraph创建了报错消息和堆栈跟踪,用来显示原始源代码中的错误源,而不仅仅是显示对生成的代码的参考。

可运行的例子

那么,AutoGraph可以为你做什么呢? 以下是一些代码示例,它可以直接转换为图形代码而无需任何更改。 如果你想查看完整的代码,我们有一个notebook,你可以在Colab或GitHub上查看。

在这里,我们使用循环和分支检测Collatz猜想。 注意,我们使用AutoGraph的.to_graph()函数将其转换为图形的原因,是为了多样性而不是为了装饰。

def collatz(a):    counter = 0    while a != 1:        if a % 2 == 0:            a = a // 2        else:            a = 3 * a + 1        counter = counter + 1    return countergraph_mode_collatz = autograph.to_graph(collatz)# The code is human-readable, tooprint(autograph.to_code(collatz))collatz_tensor = graph_mode_collatz(tf.constant(n))

AutoGraph可以支持任意嵌套控制流,例如:

def f(n):  if n >= 0:    while n < 5:      n += 1      print(n)  return n

AutoGraph允许你将元素追加到循环内的数组中。 为了达到这个要求,我们使用一些AutoGraph助手,例如set_element_type 和 stack。

def f(n):  z = []  # We ask you to tell us the element dtype of the list  autograph.set_element_type(z, tf.int32)  for i in range(n):    z.append(i)  # when you're done with the list, stack it  # (this is just like np.stack)  return autograph.stack(z)

我们还支持像break,continue,甚至print和assert这样的结构。 转换后,该片段的Python将转换为图形(使用恰当的tf.Assert)。

def f(x):  assert x != 0, 'Do not pass zero!'  return x * x

能够轻松地添加循环,控制流程以及更多图表意味着可以轻松地将训练循环移动到图形中。 这个例子可以在这个notebook中找到,我们采用RNN训练循环并用一个sess.run()调用执行它。 在需要将整个训练循环传递给加速器而不是通过CPU控制器管理训练的情况下,这可能是很有用的。

AutoGraph开辟了构建和训练模型的新思路。我们期待根据开发者社区的建议为AutoGraph添加更多功能,所以请提出你的建议和问题吧!

AutoGraph和Eager Execution

在使用eager execution时,你仍然可以通过tf.contrib.eager.defun对代码的某些部分使用图执行。这要求你使用TensorFlow图形操作,如tf.cond()。 将来,AutoGraph将与defun无缝集成,以允许在简单的eager 风格的Python中创作图形代码。 当该实现可用时,你可以通过选择性地将eager代码转换为graph fragments来使用AutoGraph加速热点。

结论

AutoGraph是一款工具,可让你轻松构建直观,复杂的模型,在TensorFlow图中轻松运行。 这是一个现在在contrib中的实验工具,但我们希望尽快将其转移到核心TensorFlow中。

告诉我们您使用AutoGraph的经历! 如果你有反馈,建议或想法,请提交问题并向TensorFlow开发人员小组发送消息。

原文链接:

原文发布时间为:2018-07-19

本文来自云栖社区合作伙伴新智元,了解相关信息可以关注“AI_era”。
原文链接:

转载地址:http://ttdda.baihongyu.com/

你可能感兴趣的文章
SQL CHECK 约束
查看>>
git提交到一半关闭时
查看>>
WMware 10 Ubuntu 12.04 进入Unity模式
查看>>
简单通用的访问CVS的方法
查看>>
kbengine mmo源码(完整服务端源码+资源+完整客户端源码)
查看>>
【操作系统】实验四 主存空间的分配和回收
查看>>
Log4j 配置 的webAppRootKey参数问题
查看>>
VMware ESXi 5.0中时间配置中NTP设置
查看>>
C++中memset()函数笔记
查看>>
oracle sql 数结构表id降序
查看>>
使用cnpm加速npm
查看>>
MySql跨服务器备份数据库
查看>>
一个字典通过dictionaryWithDictionary 他们的内存指针是不同的
查看>>
HTTP 错误 500.0的解决方法。
查看>>
CCF201612-1 中间数(解法三)(100分)
查看>>
百度前端任务一学习的知识
查看>>
C# 四个字节十六进制数和单精度浮点数之间的相互转化
查看>>
JavaNIO的总结
查看>>
阿里云总监课第五期PPT下载地址
查看>>
时间属性
查看>>