我正在使用keras 2.0.8和tensorflow 1.3.0后端.

我在类init中加载一个模型,然后用它来预测多线程.

import tensorflow as tf
from keras import backend as K
from keras.models import load_model


class CNN:
    def __init__(self,model_path):
        self.cnn_model = load_model(model_path)
        self.session = K.get_session()
        self.graph = tf.get_default_graph()

    def query_cnn(self,data):
        X = self.preproccesing(data)
        with self.session.as_default():
            with self.graph.as_default():
                return self.cnn_model.predict(X)

我初始化CNN一次,query_cnn方法从多个线程发生.

我在日志中得到的例外是:

File "/home/*/Similarity/CNN.py",line 43,in query_cnn
    return self.cnn_model.predict(X)
  File "/usr/local/lib/python3.5/dist-packages/keras/models.py",line 913,in predict
    return self.model.predict(x,batch_size=batch_size,verbose=verbose)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py",line 1713,in predict
    verbose=verbose,steps=steps)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py",line 1269,in _predict_loop
    batch_outs = f(ins_batch)
  File "/usr/local/lib/python3.5/dist-packages/keras/backend/tensorflow_backend.py",line 2273,in __call__
    **self.session_kwargs)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py",line 895,in run
    run_Metadata_ptr)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py",line 1124,in _run
    Feed_dict_tensor,options,run_Metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py",line 1321,in _do_run
    options,line 1340,in _do_call
    raise type(e)(node_def,op,message)
tensorflow.python.framework.errors_impl.NotFoundError: PruneForTargets: Some target nodes not found: group_deps

代码在大多数情况下工作正常,它可能是多线程的一些问题.

我该如何解决?

解决方法

确保在创建其他线程之前完成图形创建.

在图表上调用finalize()可以帮助您.

def __init__(self,model_path):
        self.cnn_model = load_model(model_path)
        self.session = K.get_session()
        self.graph = tf.get_default_graph()
        self.graph.finalize()

更新1:finalize()将使您的图形为只读,以便可以安全地在多个线程中使用.作为副作用,它将帮助您找到无意的行为,有时还会发现内存泄漏,因为当您尝试修改图形时它会引发异常.

想象一下,你有一个线程可以做一个例如输入的热编码. (坏的例子:)

def preprocessing(self,data):
    one_hot_data = tf.one_hot(data,depth=self.num_classes)
    return self.session.run(one_hot_data)

如果在图表中打印对象数量,您会发现它会随着时间的推移而增加

# amount of nodes in tf graph
print(len(list(tf.get_default_graph().as_graph_def().node)))

但是,如果您首先定义图形不是这种情况(略微更好的代码):

def preprocessing(self,data):
    # run pre-created operation with self.input as placeholder
    return self.session.run(self.one_hot_data,Feed_dict={self.input: data})

更新2:根据此thread,您需要在执行多线程之前在keras模型上调用model._make_predict_function().

Keras builds the GPU function the first time you call predict(). That
way,if you never call predict,you save some time and resources.
However,the first time you call predict is slightly slower than every
other time.

更新的代码:

def __init__(self,model_path):
    self.cnn_model = load_model(model_path)
    self.cnn_model._make_predict_function() # have to initialize before threading
    self.session = K.get_session()
    self.graph = tf.get_default_graph() 
    self.graph.finalize() # make graph read-only

更新3:我做了一个预热概念的证明,因为_make_predict_function()似乎没有按预期工作.
首先我创建了一个虚拟模型:

import tensorflow as tf
from keras.layers import *
from keras.models import *

model = Sequential()
model.add(Dense(256,input_shape=(2,)))
model.add(Dense(1,activation='softmax'))

model.compile(loss='mean_squared_error',optimizer='adam')

model.save("dummymodel")

然后在另一个脚本中我加载了该模型并使其在多个线程上运行

import tensorflow as tf
from keras import backend as K
from keras.models import load_model
import threading as t
import numpy as np

K.clear_session()

class CNN:
    def __init__(self,model_path):

        self.cnn_model = load_model(model_path)
        self.cnn_model.predict(np.array([[0,0]])) # warmup
        self.session = K.get_session()
        self.graph = tf.get_default_graph()
        self.graph.finalize() # finalize

    def preproccesing(self,data):
        # dummy
        return data

    def query_cnn(self,data):
        X = self.preproccesing(data)
        with self.session.as_default():
            with self.graph.as_default():
                prediction = self.cnn_model.predict(X)
        print(prediction)
        return prediction


cnn = CNN("dummymodel")

th = t.Thread(target=cnn.query_cnn,kwargs={"data": np.random.random((500,2))})
th2 = t.Thread(target=cnn.query_cnn,2))})
th3 = t.Thread(target=cnn.query_cnn,2))})
th4 = t.Thread(target=cnn.query_cnn,2))})
th5 = t.Thread(target=cnn.query_cnn,2))})
th.start()
th2.start()
th3.start()
th4.start()
th5.start()

th2.join()
th.join()
th3.join()
th5.join()
th4.join()

评论预热和最终确定的线条我能够重现你的第一个问题

多线程 – Keras Tensorflow – 从多个线程预测时的异常的更多相关文章

  1. canvas中普通动效与粒子动效的实现代码示例

    canvas用于在网页上绘制图像、动画,可以将其理解为画布,在这个画布上构建想要的效果。本文详细的介绍了粒子特效,和普通动效进行对比,非常具有实用价值,需要的朋友可以参考下

  2. H5混合开发app如何升级的方法

    本篇文章主要介绍了H5混合开发app如何升级的方法,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧

  3. canvas学习和滤镜实现代码

    这篇文章主要介绍了canvas学习和滤镜实现代码,利用 canvas,前端人员可以很轻松地、进行图像处理,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

  4. localStorage的过期时间设置的方法详解

    这篇文章主要介绍了localStorage的过期时间设置的方法详解的相关资料,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧

  5. 详解HTML5 data-* 自定义属性

    这篇文章主要介绍了详解HTML5 data-* 自定义属性的相关资料,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧

  6. HTML5 Web缓存和运用程序缓存(cookie,session)

    这篇文章主要介绍了HTML5 Web缓存和运用程序缓存(cookie,session),小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧

  7. HTML5的postMessage的使用手册

    HTML5提出了一个新的用来跨域传值的方法,即postMessage,这篇文章主要介绍了HTML5的postMessage的使用手册的相关资料,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧

  8. 教你使用Canvas处理图片的方法

    本篇文章主要介绍了教你使用Canvas处理图片的方法,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧

  9. iOS Swift上弃用后Twitter.sharedInstance().session()?. userName的替代方案

    解决方法如果您仍在寻找解决方案,请参阅以下内容:

  10. 使用Fabric SDK iOS访问Twitter用户时间线

    我试图在这个问题上挣扎两天.我正在使用FabricSDK和Rest工具包,试图为Twitter使用不同的RestAPIWeb服务.我可以使用具有authTokenSecret,authToken和其他值的会话对象的TWTRLogInButton成功登录.当我尝试获取用户时间线时,我总是得到失败的响应,作为:{“errors”:[{“code”:215,“message”:“BadAuthentic

随机推荐

  1. 基于EJB技术的商务预订系统的开发

    用EJB结构开发的应用程序是可伸缩的、事务型的、多用户安全的。总的来说,EJB是一个组件事务监控的标准服务器端的组件模型。基于EJB技术的系统结构模型EJB结构是一个服务端组件结构,是一个层次性结构,其结构模型如图1所示。图2:商务预订系统的构架EntityBean是为了现实世界的对象建造的模型,这些对象通常是数据库的一些持久记录。

  2. Java利用POI实现导入导出Excel表格

    这篇文章主要为大家详细介绍了Java利用POI实现导入导出Excel表格,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

  3. Mybatis分页插件PageHelper手写实现示例

    这篇文章主要为大家介绍了Mybatis分页插件PageHelper手写实现示例,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

  4. (jsp/html)网页上嵌入播放器(常用播放器代码整理)

    网页上嵌入播放器,只要在HTML上添加以上代码就OK了,下面整理了一些常用的播放器代码,总有一款适合你,感兴趣的朋友可以参考下哈,希望对你有所帮助

  5. Java 阻塞队列BlockingQueue详解

    本文详细介绍了BlockingQueue家庭中的所有成员,包括他们各自的功能以及常见使用场景,通过实例代码介绍了Java 阻塞队列BlockingQueue的相关知识,需要的朋友可以参考下

  6. Java异常Exception详细讲解

    异常就是不正常,比如当我们身体出现了异常我们会根据身体情况选择喝开水、吃药、看病、等 异常处理方法。 java异常处理机制是我们java语言使用异常处理机制为程序提供了错误处理的能力,程序出现的错误,程序可以安全的退出,以保证程序正常的运行等

  7. Java Bean 作用域及它的几种类型介绍

    这篇文章主要介绍了Java Bean作用域及它的几种类型介绍,Spring框架作为一个管理Bean的IoC容器,那么Bean自然是Spring中的重要资源了,那Bean的作用域又是什么,接下来我们一起进入文章详细学习吧

  8. 面试突击之跨域问题的解决方案详解

    跨域问题本质是浏览器的一种保护机制,它的初衷是为了保证用户的安全,防止恶意网站窃取数据。那怎么解决这个问题呢?接下来我们一起来看

  9. Mybatis-Plus接口BaseMapper与Services使用详解

    这篇文章主要为大家介绍了Mybatis-Plus接口BaseMapper与Services使用详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

  10. mybatis-plus雪花算法增强idworker的实现

    今天聊聊在mybatis-plus中引入分布式ID生成框架idworker,进一步增强实现生成分布式唯一ID,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

返回
顶部