Rasa is an open-source machine learning framework built in Python and designed for building conversational AI applications, such as chatbots and voice assistants.
I have been using Rasa to develop robust chatbots for two years, and one limitation I have encountered is the training process. Once you have trained a model, changing even a single response text in one of your domain files requires retraining the model. For a model that takes a long time to train, this quickly becomes time-consuming.
The most common answer I found in GitHub issues and on the Rasa forum was simply to retrain the model.
Since Rasa does not provide a CLI command to update a model without retraining, in this article I will share a simple hack I found by reverse engineering the training process.
The Command Line Interface (CLI) of Rasa provides commands to create a project, interact with your dataset, train a model, and more. The train command, whose options you can see by running rasa train --help, is responsible for creating your model. So, I dove into the Rasa source code on GitHub to understand what this command does behind the scenes.
After several hours, I was able to find these three essential files:
domain.yml file.
First, Rasa uses the Python library tarsafe to create a model archive. The following is an example of the contents of a model archive:
📘model_name/
┣ 📁components/
┃ ┣ 📁domain_provider/ ⭐
┃ ┣ 📁finetuning_validator/
┃ ┣ 📁train_AugmentedMemoizationPolicy0/
┃ ┣ 📁train_CountVectorsFeaturizer3/
┃ ┣ 📁train_CountVectorsFeaturizer4/
┃ ┣ 📁train_DIETClassifier5/
┃ ┣ 📁train_LexicalSyntacticFeaturizer2/
┃ ┣ 📁train_RegexFeaturizer1/
┃ ┗ 📁train_RulePolicy2/
┗ ✨metadata.json ⭐
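To look at this layout in one of your own models without unpacking anything to disk, you can list the archive members and read metadata.json with Python's standard tarfile and json modules. This is a minimal sketch: it assumes the model is a gzip-compressed tar file with metadata.json at the archive root, as in the tree above, and the function name inspect_model_archive is hypothetical.

```python
import json
import posixpath
import tarfile
from typing import Any, Dict, List, Tuple


def inspect_model_archive(archive_path: str) -> Tuple[List[str], Dict[str, Any]]:
    """Return the member names of a model archive and its parsed metadata.json.

    Assumption: gzip-compressed tar with metadata.json at the archive root.
    Everything is read in memory; nothing is extracted to disk.
    """
    with tarfile.open(archive_path, "r:gz") as tar:
        members = [m for m in tar.getmembers() if m.name]  # skip the unnamed root entry
        names = [posixpath.normpath(m.name) for m in members]
        metadata_member = next(
            (m for m in members if posixpath.normpath(m.name) == "metadata.json"),
            None,
        )
        if metadata_member is None:
            raise FileNotFoundError(f"metadata.json not found in {archive_path}")
        # json.load accepts the binary file object returned by extractfile
        metadata = json.load(tar.extractfile(metadata_member))
    return names, metadata
```

For example, inspect_model_archive("models/<your-model>.tar.gz")[1]["domain"]["responses"] shows the response texts currently stored in the model's metadata.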
The domain_provider folder contains the domain.yml file, which is the result of merging all of your domain files. Rasa performs this merge automatically when you run the rasa train command. Later, we will update this file.
The metadata.json file contains metadata about the model, such as the domain as JSON, the training date, the version of Rasa used to train the model, the language, and more.
Based on this, updating a model without retraining comes down to these steps:
1. Unpack the model archive.
2. Remove the old domain.yml file from the domain_provider folder.
3. Merge all of your domain files into a new domain.yml file.
4. Copy the domain.yml file obtained in (3) to the domain_provider folder.
5. Update the metadata.json file, especially its domain property, with the JSON of the domain file obtained in (3).
6. Archive the model again using tarsafe.
This workaround should not be used on a production model. I use it primarily to speed up my development process.
That being said, let's start coding.
import os
import shutil
from pathlib import Path

from rasa.model import get_latest_model
from rasa.shared.core.domain import Domain
from rasa.shared.utils.io import dump_obj_as_json_to_file, read_json_file
from tarsafe import TarSafe

TRAINED_MODEL_PATH = "models/"
DOMAIN_DIRECTORY_PATH = "./"
TEMP_DIR = "rasa_toolkit/"


def update_rasa_model(trained_model_path: str = TRAINED_MODEL_PATH,
                      domain_directory_path: str = DOMAIN_DIRECTORY_PATH) -> None:
    """Update the latest Rasa model with the domain files in domain_directory_path.

    trained_model_path: path to the directory containing the trained models. Default: models/
    domain_directory_path: path to the directory containing the domain files. Default: ./
    """
    # Step 1: Clean up the temporary directory
    if os.path.exists(TEMP_DIR):
        shutil.rmtree(TEMP_DIR)
    os.makedirs(TEMP_DIR)
    print(f"Temporary directory {TEMP_DIR} created")

    # Step 2: Load the domain files
    domain = Domain.from_directory(domain_directory_path)

    # Step 3: Merge the domain files into a single domain.yml
    merged_domain_path = Path(f"{TEMP_DIR}domain.yml")
    domain.persist(merged_domain_path)

    # Step 4: Get the latest model and its name
    model_archive_path = get_latest_model(trained_model_path)
    if model_archive_path is None:
        raise FileNotFoundError(f"No trained model found in {trained_model_path}")
    print(f"Latest model found at {model_archive_path}")
    # Path.stem only drops ".gz", so also strip the remaining ".tar"
    model_name = Path(model_archive_path).stem[:-4]

    # Step 5: Unpack the model archive
    storage_path = Path(f"{TEMP_DIR}{model_name}")
    if storage_path.exists():
        shutil.rmtree(storage_path)
    with TarSafe.open(model_archive_path, "r:gz") as tar:
        tar.extractall(storage_path)

    # Step 6: Remove the old domain file from the unpacked model, if present
    domain_provider_dir = storage_path / "components" / "domain_provider"
    old_domain_file_path = domain_provider_dir / "domain.yml"
    if old_domain_file_path.exists():
        old_domain_file_path.unlink()

    # Step 7: Copy the new domain file into the unpacked model
    domain_provider_dir.mkdir(parents=True, exist_ok=True)
    shutil.copyfile(merged_domain_path, old_domain_file_path)

    # Step 8: Update the domain property of metadata.json in the unpacked model
    metadata_path = storage_path / "metadata.json"
    json_metadata = read_json_file(metadata_path)
    json_metadata["domain"] = domain.as_dict()
    dump_obj_as_json_to_file(metadata_path, json_metadata)

    # Step 9: Re-archive the unpacked model using TarSafe
    archive_path = Path(f"{TEMP_DIR}{model_name}.tar.gz")
    with TarSafe.open(archive_path, "w:gz") as tar:
        tar.add(storage_path, arcname="")
    print(f"Replace your latest model with the updated version located at {archive_path} !")

    # Clean up the merged domain file and the unpacked model directory
    merged_domain_path.unlink()
    shutil.rmtree(storage_path)


if __name__ == "__main__":
    # Update the latest Rasa model after updating responses in the domain files
    update_rasa_model()
To use the code provided above, save it into a file called main.py and place it at the root of your Rasa project. Alternatively, you can download the code from GitHub.
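Because the trained components inside the archive are left untouched, this trick is presumably only safe when your edits are limited to response texts; changes to intents, entities, slots, or actions most likely still call for retraining. If you want a guard against applying it to such changes by accident, a minimal sketch could compare the domain stored in metadata.json with the new one at the level of top-level domain keys. The helper names changed_domain_keys and only_responses_changed are hypothetical and not part of Rasa.

```python
import json
from typing import Any, Dict, List


def _normalize(domain: Dict[str, Any]) -> Dict[str, Any]:
    # Round-trip through JSON so that, e.g., tuples and lists compare equal,
    # as they would after the domain has been stored in metadata.json.
    return json.loads(json.dumps(domain))


def changed_domain_keys(old_domain: Dict[str, Any], new_domain: Dict[str, Any]) -> List[str]:
    """Return the sorted top-level keys whose values differ between two domain dicts."""
    old, new = _normalize(old_domain), _normalize(new_domain)
    return sorted(k for k in set(old) | set(new) if old.get(k) != new.get(k))


def only_responses_changed(old_domain: Dict[str, Any], new_domain: Dict[str, Any]) -> bool:
    """True if the two domains differ at most in their "responses" section."""
    return set(changed_domain_keys(old_domain, new_domain)) <= {"responses"}
```

Inside update_rasa_model, one option is to call only_responses_changed(json_metadata["domain"], domain.as_dict()) right after reading json_metadata in Step 8 and raise an error when it returns False.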
Below is an example of the output when running python main.py:
(rasa-demo-py3.10) ⬢ rasa_demo ◉
> python main.py
Temporary directory rasa_toolkit/ created
Latest model found at models\20230305-195755-glum-meander.tar.gz
Replace your latest model with the updated version located at rasa_toolkit\20230305-195755-glum-meander.tar.gz !
(rasa-demo-py3.10) ⬢ rasa_demo ◉
Training my latest model took 20 minutes, while the update trick took only 2 seconds: a significant time saving ⚡.
You can now replace your latest model with the updated archive and start a chat session with rasa shell, or start the server with rasa run, to verify the changes.
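If you prefer to script this last step, here is a minimal sketch using only the standard library. It assumes the updated archive has the same file name as the original model (as in the output above, where both are named 20230305-195755-glum-meander.tar.gz), keeps a .bak copy of the original so you can roll back, and then reads metadata.json from the new archive to confirm that a response has the expected text. The function names replace_model_with_backup and response_texts, and the utter_greet response in the usage note, are hypothetical.

```python
import json
import posixpath
import shutil
import tarfile
from pathlib import Path
from typing import List


def replace_model_with_backup(updated_archive: str, models_dir: str = "models") -> Path:
    """Copy updated_archive into models_dir, backing up a model with the same name first."""
    source = Path(updated_archive)
    target = Path(models_dir) / source.name
    if target.exists():
        # Keep the original trained model as "<name>.tar.gz.bak" for rollback.
        shutil.copy2(target, target.with_name(target.name + ".bak"))
    shutil.copy2(source, target)
    return target


def response_texts(archive_path: str, response_name: str) -> List[str]:
    """Return the texts of one response as stored in the archive's metadata.json."""
    with tarfile.open(archive_path, "r:gz") as tar:
        for member in tar.getmembers():
            if member.name and posixpath.normpath(member.name) == "metadata.json":
                metadata = json.load(tar.extractfile(member))
                break
        else:
            raise FileNotFoundError(f"metadata.json not found in {archive_path}")
    variations = metadata["domain"]["responses"].get(response_name, [])
    # Responses without a "text" field (e.g. image-only ones) are skipped.
    return [v["text"] for v in variations if "text" in v]
```

For example, response_texts(str(replace_model_with_backup("rasa_toolkit/<your-model>.tar.gz")), "utter_greet") should list the new text. The backup gets a .bak suffix rather than .tar.gz, on the assumption that Rasa only considers *.tar.gz files when looking for the latest model.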
The complete source code is available on GitHub. Feel free to submit a pull request for improvements!
If you found this blog post helpful for your Rasa NLU projects, please share it with others and leave a comment with your thoughts and feedback. Also, make sure to follow me on GitHub and Twitter to stay up-to-date with my latest posts.