Streamlit is an increasingly popular tool that lets Python developers turn data scripts into interactive web applications in a few lines of code. I recently developed and deployed a semantic search app for Chinese news articles, and I made the mistake of not caching the model-loading code. The performance was abysmal, and the memory footprint was huge for a TinyBERT-4L model (I had to allocate 1 GB of memory for the app).
Thankfully, Randy Zwitch (@randyzwitch) from Streamlit pointed out the problem to me on Twitter:
> To avoid having the model reloaded, you should use the st.cache decorator: https://t.co/Ggmd5ZrmD7
>
> — Randy Zwitch (@randyzwitch), March 8, 2021
By correctly caching the FAISS index and moving the PyTorch model behind a separate API server (to support other use cases), I managed to dramatically increase the speed and reduce the memory consumption of the app (it now needs at most 600 MB of memory in total):
Here’s a longer version of the story. I’m used to the way Flask and FastAPI handle things: models are loaded as global variables, and caching is only needed for repeated inputs that require heavy computation or incur significant I/O latency.
As an example, the following excerpt is from this FastAPI script, which returns sentence embedding vectors for the input sentences:
```python
APP = FastAPI()

if os.environ.get("MODEL", None):
    MODEL = SentenceEncoder(os.environ["MODEL"], device="cpu")

@APP.post("/", response_model=EmbeddingsResult)
def get_embeddings(text_input: TextInput):
    assert MODEL is not None, "MODEL is not loaded."
    text = text_input.text.replace("\n", " ")
    if text_input.t2s:
        text = T2S.convert(text)
    vector = MODEL.encode(
        [text], batch_size=1, show_progress_bar=False
    )
    return EmbeddingsResult(vectors=vector.tolist())
```
When developing the Streamlit app, I did something similar to this:
```python
encoder = SentenceEncoder(
    "streamlit_model/", device="cpu"
).eval()

def main():
    st.title('Veritable News Semantic Search Engine')
    # ...
    if len(query) > 10 and len(date_range) == 2:
        embs = encoder.encode(
            [query], batch_size=1, show_progress_bar=True
        )
    # ...
```
Because Streamlit re-runs the entire script on every interaction (and on some other events, it seems), the model was repeatedly loaded into memory (you can verify this from the log), with devastating results. When running on my low-end Hetzner VPS instance, the app sometimes simply failed to load (it got stuck at the initial “connecting” animation).
The correct way to load large objects from disk is to use the @st.cache function decorator. It’s a simple in-memory cache designed for Streamlit’s unique execution model: Streamlit stores the function’s results on the first run and hands back references to those results on subsequent runs.
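To make the semantics concrete, here is a rough pure-Python sketch of the memoization idea behind st.cache (this is an illustration, not Streamlit’s actual implementation — the real decorator also hashes function arguments and code):

```python
import functools

def simple_cache(func):
    """Toy sketch of the idea behind st.cache: run the function
    once, then hand back the same object (a reference, not a
    copy) on every later call."""
    results = {}

    @functools.wraps(func)
    def wrapper(*args):
        if args not in results:
            results[args] = func(*args)  # first run: compute and store
        return results[args]             # later runs: same reference
    return wrapper

calls = []

@simple_cache
def load_model(path):
    calls.append(path)          # track how many times we actually "load"
    return {"weights": path}    # stand-in for an expensive model load

a = load_model("streamlit_model/")
b = load_model("streamlit_model/")
print(len(calls), a is b)  # 1 True — loaded once, same object returned
```

The key point is the last line: the second call never re-runs the function body; both names point at the very same object.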
The following is taken from this script, with some modifications (I put the sentence encoder back in):
```python
@st.cache(allow_output_mutation=True)
def load_data():
    conn = sqlite3.connect("data/news.sqlite")
    full_ids = joblib.load("data/ids.jbl")
    index = faiss.read_index("data/index.faiss")
    default_date_range = [
        datetime.date(2018, 11, 28), datetime.date.today()
    ]
    encoder = SentenceEncoder(
        "streamlit_model/", device="cpu"
    ).eval()
    return conn, full_ids, index, default_date_range, encoder
```
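A note on the allow_output_mutation=True flag: by default, st.cache fingerprints the cached value and warns if callers mutate it between runs, but objects like sqlite connections can’t be fingerprinted (or are mutated by design), so the check is disabled here. The following toy sketch illustrates the idea of that mutation check — again, this is my own illustration, not Streamlit’s actual mechanism:

```python
import pickle

def cache_with_mutation_check(func):
    """Toy sketch: cache the result, but fingerprint it and
    complain if a caller mutated the cached object."""
    state = {}

    def wrapper():
        if "value" not in state:
            state["value"] = func()
            state["digest"] = pickle.dumps(state["value"])  # fingerprint
        elif pickle.dumps(state["value"]) != state["digest"]:
            raise RuntimeError("cached value was mutated between runs")
        return state["value"]
    return wrapper

@cache_with_mutation_check
def load_config():
    return {"date_range_days": 30}  # stand-in for a cached object

cfg = load_config()
cfg["date_range_days"] = 60         # mutate the cached object in place
try:
    load_config()                   # the toy check now fires
except RuntimeError as e:
    print("caught:", e)
```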
One thing to emphasize again is that Streamlit hands out references to cached function results on every run. No copying is involved, so enabling caching does not double the memory requirement. In fact, in my observations, repeatedly reloading the model has a larger memory footprint. The likely reason is Python’s memory-management behavior: when you rapidly allocate new objects, the interpreter may not release the memory held by the previous objects fast enough, so old and new copies briefly coexist.
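You can observe this briefly-doubled footprint with tracemalloc. The sketch below uses a bytearray as a stand-in for model weights; the sizes are illustrative, not the app’s actual numbers:

```python
import tracemalloc

def load_model_blob():
    # stand-in for loading ~50 MB of model weights from disk
    return bytearray(50 * 1024 * 1024)

# "reload on every run" pattern: while the second load executes,
# the old blob is still referenced, so both copies live at once
tracemalloc.start()
blob = load_model_blob()
blob = load_model_blob()
_, reload_peak = tracemalloc.get_traced_memory()
tracemalloc.stop()

# "cache once, hand out references" pattern: a single load, reused
tracemalloc.start()
cached = load_model_blob()
ref1, ref2 = cached, cached   # references, no copies
_, cached_peak = tracemalloc.get_traced_memory()
tracemalloc.stop()

print(reload_peak > cached_peak)  # True — reloading peaks higher
```

The reload pattern peaks at roughly two blobs’ worth of memory, because the new object is allocated before the name is rebound and the old one can be freed.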