Hi folks,
I have trained a model with the Keras framework, exported it with model.save('model.hdf5'), and now I want to integrate it with the awesome Streamlit.
Obviously, I do not want to load the model every time the end user submits a new input, but to load it once and for all.
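(For context, the save step looked roughly like this; the architecture below is only illustrative, not my actual network:)

from keras.models import Sequential
from keras.layers import Embedding, LSTM, Dense

# Illustrative only: some text model trained elsewhere and saved to disk.
model = Sequential([
    Embedding(input_dim=20000, output_dim=128, input_length=80),
    LSTM(64),
    Dense(1, activation='sigmoid'),
])
model.compile(optimizer='adam', loss='binary_crossentropy')
# model.fit(...) on the training data, then:
model.save('model.hdf5')  # the file loaded by the Streamlit app below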
So my code looks something like this:

import streamlit as st
from keras.models import load_model

@st.cache
def load_my_model():
    model = load_model('model.hdf5')
    model.summary()
    return model

if __name__ == '__main__':
    st.title('My first app')
    sentence = st.text_input('Input your sentence here:')
    model = load_my_model()
    if sentence:
        y_hat = model.predict(sentence)

With that, I got the following exception:

“streamlit.errors.UnhashableType: <exception str() failed>”
I then tried @st.cache(allow_output_mutation=True), but when I run a query on the Streamlit page I get:

“TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor(“input_1:0”, shape=(?, 80), dtype=int32) is not an element of this graph.”

(Of course, without any cache decorator the model loads and works fine.)

How should I properly load and cache a trained Keras model?

Python ver: 2.7 (unfortunately)
Keras ver: 2.1.3
Tensorflow ver: 1.3.0
Streamlit ver: 0.55.2

Many thanks!

Hi @tc1,
Unfortunately, the source code I use is not trivial to migrate to Python 3.6 (so I probably cannot upgrade my Streamlit version either).
Any suggestions with my current versions?

Thank you.

import streamlit as st
import keras.backend as K
from keras.models import load_model as keras_load_model

@st.cache(allow_output_mutation=True)
def load_model():
    model = keras_load_model(MODEL_PATH)
    model._make_predict_function()
    model.summary()  # included to make it visible when model is reloaded
    session = K.get_session()
    return model, session

if __name__ == '__main__':
    st.title('My first app')
    sentence = st.text_input('Input your sentence here:')
    model, session = load_model()
    if sentence:
        K.set_session(session)
        y_hat = model.predict(sentence)
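A variant of the same idea (not from this thread, just the usual TF1 workaround for the “not an element of this graph” error) is to cache the graph instead of the session and enter it before predicting; the helper name below is hypothetical:

import tensorflow as tf

@st.cache(allow_output_mutation=True)
def load_model_and_graph():
    # Load the model once and remember the graph it was loaded into.
    model = keras_load_model(MODEL_PATH)
    model._make_predict_function()
    graph = tf.get_default_graph()
    return model, graph

# before predicting:
#     model, graph = load_model_and_graph()
#     with graph.as_default():
#         y_hat = model.predict(sentence)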

Thanks for the solution. In TensorFlow 2 we should use this method instead (since sessions were removed in TF2):

import streamlit as st
from tensorflow.keras.models import load_model as keras_load_model

@st.cache(allow_output_mutation=True)
def load_model():
    # No session handling is needed in TF2; allow_output_mutation just stops
    # Streamlit from trying to hash the returned model.
    model = keras_load_model(MODEL_PATH)
    model.summary()  # included to make it visible when model is reloaded
    return model

if __name__ == '__main__':
    st.title('My first app')
    sentence = st.text_input('Input your sentence here:')
    model = load_model()
    if sentence:
        y_hat = model.predict(sentence)
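One more note: in the snippets above, model.predict(sentence) is called on the raw string. For a model whose input is shape (?, 80), as the error message earlier suggests, the sentence normally has to be tokenized and padded first. A minimal sketch, assuming a Keras Tokenizer fitted at training time and a sequence length of 80 (both are assumptions, not from this thread):

from keras.preprocessing.sequence import pad_sequences

def encode(sentence, tokenizer, maxlen=80):
    # Turn the raw sentence into a padded sequence of token ids,
    # matching the (?, 80) input the model was trained on.
    # `tokenizer` is assumed to be the Tokenizer used at training time,
    # loaded (and cached) alongside the model.
    seq = tokenizer.texts_to_sequences([sentence])
    return pad_sequences(seq, maxlen=maxlen)

# inside the app:
#     x = encode(sentence, tokenizer)
#     y_hat = model.predict(x)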