I load a model in my class's __init__ and then use it to predict from multiple threads.
import tensorflow as tf
from keras import backend as K
from keras.models import load_model

class CNN:
    def __init__(self, model_path):
        self.cnn_model = load_model(model_path)
        self.session = K.get_session()
        self.graph = tf.get_default_graph()

    def query_cnn(self, data):
        X = self.preproccesing(data)
        with self.session.as_default():
            with self.graph.as_default():
                return self.cnn_model.predict(X)
I initialize CNN once, and query_cnn is called from multiple threads.
The exception I get in the log is:
File "/home/*/Similarity/CNN.py",line 43,in query_cnn
return self.cnn_model.predict(X)
File "/usr/local/lib/python3.5/dist-packages/keras/models.py",line 913,in predict
return self.model.predict(x,batch_size=batch_size,verbose=verbose)
File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py",line 1713,in predict
verbose=verbose,steps=steps)
File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py",line 1269,in _predict_loop
batch_outs = f(ins_batch)
File "/usr/local/lib/python3.5/dist-packages/keras/backend/tensorflow_backend.py",line 2273,in __call__
**self.session_kwargs)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py",line 895,in run
run_metadata_ptr)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py",line 1124,in _run
feed_dict_tensor,options,run_metadata)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py",line 1321,in _do_run
options,line 1340,in _do_call
raise type(e)(node_def,op,message)
tensorflow.python.framework.errors_impl.NotFoundError: PruneForTargets: Some target nodes not found: group_deps
The code works fine most of the time, so it is probably a multithreading issue.
How can I fix it?
Solution
Calling finalize() on the graph may help.
def __init__(self, model_path):
    self.cnn_model = load_model(model_path)
    self.session = K.get_session()
    self.graph = tf.get_default_graph()
    self.graph.finalize()
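Update 1: finalize() makes your graph read-only, so it can be used safely from multiple threads. As a side effect, it helps you find unintended behavior, and sometimes memory leaks, because it raises an exception whenever you try to modify the graph.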
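To see the read-only behavior in isolation, here is a minimal sketch that does not involve Keras at all. It assumes a TensorFlow build that provides tf.Graph, Graph.finalize() and the Graph.finalized property; the names graph and raised are only illustrative and do not come from the question.

```python
import tensorflow as tf

# A private graph, so the sketch does not touch the default graph.
graph = tf.Graph()
with graph.as_default():
    tf.constant(1.0)  # allowed: the graph is still mutable

graph.finalize()  # from now on the graph is read-only

raised = False
try:
    with graph.as_default():
        tf.constant(2.0)  # would add a node, so it is rejected
except RuntimeError as err:
    raised = True
    print("graph modification rejected:", err)

print(graph.finalized, raised)
```

Any code path that silently adds nodes at prediction time should fail loudly in the same way once the graph is finalized, which is what makes it useful for debugging.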
Imagine you have a thread that does, for example, one-hot encoding of its inputs. (Bad example:)
def preprocessing(self, data):
    one_hot_data = tf.one_hot(data, depth=self.num_classes)
    return self.session.run(one_hot_data)
If you print the number of nodes in the graph, you will notice that it grows over time:
# number of nodes in the TF graph
print(len(list(tf.get_default_graph().as_graph_def().node)))
This is not the case if you define the graph up front (slightly better code):
def preprocessing(self, data):
    # run pre-created operation with self.input as placeholder
    return self.session.run(self.one_hot_data, feed_dict={self.input: data})
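The difference between the two variants can be checked by counting nodes. The following is a rough sketch, assuming a TensorFlow version that ships tf.compat.v1 (the alias tf1, the helper node_count and the toy depth of 3 are my own choices for illustration, not part of the original code).

```python
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: tf.compat.v1 is available in this TensorFlow version

graph = tf.Graph()
with graph.as_default():
    # Build the one-hot op once, fed through a placeholder.
    inp = tf1.placeholder(tf.int32, shape=[None])
    one_hot = tf.one_hot(inp, depth=3)
    session = tf1.Session(graph=graph)

def node_count():
    """Number of nodes currently stored in the graph."""
    return len(graph.as_graph_def().node)

before = node_count()

# Better pattern: re-running the pre-built op adds nothing to the graph.
for _ in range(3):
    session.run(one_hot, feed_dict={inp: [0, 1, 2]})
after_reuse = node_count()

# Bad pattern: creating the op at call time adds new nodes on every call.
with graph.as_default():
    session.run(tf.one_hot([0, 1, 2], depth=3))
after_bad = node_count()

print(before, after_reuse, after_bad)
session.close()
```

Here after_reuse should equal before, while after_bad should be larger, which is the growth the bad example causes on every call.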
Update 2: According to this thread, you need to call model._make_predict_function() on your Keras model before doing any multithreading.
Keras builds the GPU function the first time you call predict(). That
way, if you never call predict, you save some time and resources.
However, the first time you call predict is slightly slower than every
other time.
Updated code:
def __init__(self, model_path):
    self.cnn_model = load_model(model_path)
    self.cnn_model._make_predict_function()  # have to initialize before threading
    self.session = K.get_session()
    self.graph = tf.get_default_graph()
    self.graph.finalize()  # make graph read-only
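Why building before threading matters can be shown with a pure-Python analogy. This is not Keras code: LazyModel, build_count and worker are hypothetical names, and the lazy check-then-build step only stands in for whatever Keras does on the first predict() call.

```python
import threading

class LazyModel:
    """Hypothetical stand-in for a model that builds its predict function lazily."""

    def __init__(self):
        self._predict_fn = None
        self.build_count = 0

    def _make_predict_function(self):
        # Check-then-build is not thread-safe: two threads arriving here
        # at the same time could both see None and both build.
        if self._predict_fn is None:
            self.build_count += 1
            self._predict_fn = lambda xs: [v * 2 for v in xs]

    def predict(self, xs):
        self._make_predict_function()
        return self._predict_fn(xs)

model = LazyModel()
model.predict([0])  # warm-up in the main thread, before any worker starts

results = []

def worker(data):
    # list.append is atomic in CPython, so no extra lock is used here
    results.append(model.predict(data))

threads = [threading.Thread(target=worker, args=([i, i + 1],)) for i in range(5)]
for th in threads:
    th.start()
for th in threads:
    th.join()

print(model.build_count, sorted(results))
```

Because the build happens once in the main thread, the workers only ever read an already-built function, so the race in the lazy step cannot occur.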
Update 3: I put together a proof of concept of a warm-up, because _make_predict_function() does not seem to work as expected.
First I created a dummy model:
import tensorflow as tf
from keras.layers import *
from keras.models import *

model = Sequential()
model.add(Dense(256, input_shape=(2,)))
model.add(Dense(1, activation='softmax'))
model.compile(loss='mean_squared_error', optimizer='adam')
model.save("dummymodel")
Then, in another script, I loaded that model and ran it on multiple threads:
import tensorflow as tf
from keras import backend as K
from keras.models import load_model
import threading as t
import numpy as np

K.clear_session()

class CNN:
    def __init__(self, model_path):
        self.cnn_model = load_model(model_path)
        self.cnn_model.predict(np.array([[0, 0]]))  # warmup
        self.session = K.get_session()
        self.graph = tf.get_default_graph()
        self.graph.finalize()  # finalize

    def preproccesing(self, data):
        # dummy
        return data

    def query_cnn(self, data):
        X = self.preproccesing(data)
        with self.session.as_default():
            with self.graph.as_default():
                prediction = self.cnn_model.predict(X)
        print(prediction)
        return prediction

cnn = CNN("dummymodel")

# every thread gets its own random batch of 500 two-feature samples
th = t.Thread(target=cnn.query_cnn, kwargs={"data": np.random.random((500, 2))})
th2 = t.Thread(target=cnn.query_cnn, kwargs={"data": np.random.random((500, 2))})
th3 = t.Thread(target=cnn.query_cnn, kwargs={"data": np.random.random((500, 2))})
th4 = t.Thread(target=cnn.query_cnn, kwargs={"data": np.random.random((500, 2))})
th5 = t.Thread(target=cnn.query_cnn, kwargs={"data": np.random.random((500, 2))})

th.start()
th2.start()
th3.start()
th4.start()
th5.start()

th2.join()
th.join()
th3.join()
th5.join()
th4.join()
By commenting out the warm-up and finalize lines, I was able to reproduce your original problem.