From fa3b8e939e38806f56f93c210466a6b3065c2c66 Mon Sep 17 00:00:00 2001 From: Daumas Carneiro Marina <marina.daumas-carneiro@student-cs.fr> Date: Sat, 25 Mar 2023 01:07:51 +0000 Subject: [PATCH] Update save_model.py --- save_model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/save_model.py b/save_model.py index 35b2533..b80d0ae 100644 --- a/save_model.py +++ b/save_model.py @@ -1,4 +1,6 @@ import matplotlib.pyplot as plt +import os +import torch class SaveBestModel: """ @@ -27,6 +29,8 @@ class SaveBestModel: def plot_results(self, epoch_table, output_train, output_test): save_path = f'{self.images_path}/{self.feature}/{self.session_to_test}' + if not os.path.exists(save_path) + os.makedirs(save_path) plt.plot(epoch_table, output_train["loss"], label="Train") plt.plot(epoch_table, output_test["loss"], label="Test") -- GitLab