From fa3b8e939e38806f56f93c210466a6b3065c2c66 Mon Sep 17 00:00:00 2001
From: Daumas Carneiro Marina <marina.daumas-carneiro@student-cs.fr>
Date: Sat, 25 Mar 2023 01:07:51 +0000
Subject: [PATCH] Update save_model.py

---
 save_model.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/save_model.py b/save_model.py
index 35b2533..b80d0ae 100644
--- a/save_model.py
+++ b/save_model.py
@@ -1,4 +1,6 @@
 import matplotlib.pyplot as plt
+import os
+import torch
 
 class SaveBestModel:
     """
@@ -27,6 +29,8 @@ class SaveBestModel:
     def plot_results(self, epoch_table, output_train, output_test):
         
         save_path = f'{self.images_path}/{self.feature}/{self.session_to_test}'
+        if not os.path.exists(save_path)
+        os.makedirs(save_path)
 
         plt.plot(epoch_table, output_train["loss"], label="Train")
         plt.plot(epoch_table, output_test["loss"], label="Test")
-- 
GitLab