| | """
|
| | Contains various utility functions for PyTorch model training and saving.
|
| | """
|
| | import torch
|
| | from pathlib import Path
|
| | import matplotlib.pyplot as plt
|
| | import torchvision
|
| | from PIL import Image
|
| | from torch.utils.tensorboard.writer import SummaryWriter
|
| |
|
def save_model(model: torch.nn.Module,
               target_dir: str,
               model_name: str):
    """Save a PyTorch model's ``state_dict`` to a target directory.

    Args:
        model: target PyTorch model; only its ``state_dict()`` is saved.
        target_dir: path of the directory to store the saved model in. It is
            created (including missing parents) if it does not exist.
        model_name: filename for the saved model. Must end with ".pth" or
            ".pt" as the file extension.

    Returns:
        pathlib.Path: the path the model's state dict was written to.

    Raises:
        ValueError: if ``model_name`` does not end with ".pth" or ".pt".
    """
    # Validate before touching the filesystem so a bad name does not leave an
    # empty directory behind. An explicit raise (unlike ``assert``) is not
    # stripped when Python runs with -O.
    if not model_name.endswith((".pth", ".pt")):
        raise ValueError("model name should end with .pt or .pth")

    target_dir_path = Path(target_dir)
    target_dir_path.mkdir(parents=True, exist_ok=True)
    model_save_path = target_dir_path / model_name

    print(f"[INFO] Saving model to: {model_save_path}")
    torch.save(obj=model.state_dict(), f=model_save_path)
    return model_save_path
|
| |
|
def pred_and_plot_image(
    model: torch.nn.Module,
    image_path: str,
    class_names: list[str] = None,
    transform=None,
    device: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
):
    """Make a prediction on a target image with a trained model and plot it.

    The image is loaded with PIL, converted to a tensor by ``transform``,
    passed through ``model`` under ``torch.inference_mode()``, and drawn on
    the current matplotlib axes with the predicted label and its softmax
    probability as the title.

    Args:
        model (torch.nn.Module): trained PyTorch image classification model.
            Note that it is moved to ``device`` and switched to eval mode
            in place.
        image_path (str): filepath to target image.
        class_names (list[str], optional): class names indexed by the model's
            output index. If None/empty, the raw predicted index is shown.
            Defaults to None.
        transform (callable, optional): transform turning a PIL image into a
            (C, H, W) tensor. Defaults to
            ``torchvision.transforms.ToTensor()`` when None.
        device (torch.device, optional): target device to compute on.
            Defaults to "cuda" if torch.cuda.is_available() else "cpu".

    Returns:
        None. The image and prediction are drawn on the current matplotlib
        figure; call ``plt.show()`` or ``plt.savefig(...)`` to display/save.

    Example usage:
        pred_and_plot_image(model=model,
                            image_path="some_image.jpeg",
                            class_names=["class_1", "class_2", "class_3"],
                            transform=torchvision.transforms.ToTensor(),
                            device=device)
    """
    if transform is None:
        # Without a transform there is no tensor to feed the model; ToTensor
        # is the minimal PIL -> (C, H, W) float conversion.
        transform = torchvision.transforms.ToTensor()

    # PIL opens files lazily and keeps the handle open; the context manager
    # closes it once the transform has read the pixel data.
    with Image.open(image_path) as img:
        target_image = transform(img)

    model.to(device)
    model.eval()
    with torch.inference_mode():
        # Add a batch dimension: (C, H, W) -> (1, C, H, W).
        target_image = target_image.unsqueeze(dim=0)
        target_image_pred = model(target_image.to(device))
        target_image_pred_probs = torch.softmax(target_image_pred, dim=1)
        target_image_pred_label = torch.argmax(target_image_pred_probs, dim=1)

    # Drop only the batch dim (a bare squeeze() could also collapse spatial
    # dims of size 1), then go channels-last for matplotlib.
    image_to_plot = target_image.squeeze(dim=0).permute(1, 2, 0).cpu()
    if image_to_plot.shape[-1] == 1:
        # matplotlib cannot draw (H, W, 1); plot single-channel as 2-D.
        image_to_plot = image_to_plot.squeeze(dim=-1)
    plt.imshow(image_to_plot, cmap="gray" if image_to_plot.ndim == 2 else None)

    pred_index = target_image_pred_label.item()
    pred_prob = target_image_pred_probs.max().item()
    if class_names:
        title = f"Pred: {class_names[pred_index]} | Prob: {pred_prob:.3f}"
    else:
        title = f"Pred: {pred_index} | Prob: {pred_prob:.3f}"
    plt.title(title)
    plt.axis(False)
|
| |
|
def set_seeds(seed: int = 42):
    """Seed PyTorch's random number generators for reproducible runs.

    The same ``seed`` is handed to ``torch.manual_seed`` and to
    ``torch.cuda.manual_seed``; the latter is silently ignored when CUDA
    is not available.

    Args:
        seed (int, optional): Random seed to set. Defaults to 42.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
|
| |
|
| |
|
# NOTE: the return annotation names the class as a string. Writing
# ``SummaryWriter()`` there would construct a writer -- and create a log
# directory -- every time this module is imported.
def create_writer(experiment_name: str, model_name: str, extra: str = None) -> "SummaryWriter":
    """Create a ``SummaryWriter`` instance saving to a specific log_dir.

    log_dir is a combination of runs/timestamp/experiment_name/model_name/extra,
    where timestamp is the current local date in YYYY-MM-DD format. The
    ``extra`` component is only appended when it is truthy.

    Args:
        experiment_name (str): Name of experiment.
        model_name (str): Model name.
        extra (str, optional): Anything extra to add to the directory.
            Defaults to None.

    Returns:
        SummaryWriter: Instance of a writer saving to log_dir.

    Example usage:
        # This creates a writer saving to
        # "runs/2022-06-04/data_10_percent/effnetb2/5_epochs"
        writer = create_writer(experiment_name="data_10_percent",
                               model_name="effnetb2",
                               extra="5_epochs")

        # This is the same as:
        writer = SummaryWriter(log_dir="runs/2022-06-04/data_10_percent/effnetb2/5_epochs")
    """
    from datetime import datetime
    import os

    timestamp = datetime.now().strftime("%Y-%m-%d")

    path_parts = ["runs", timestamp, experiment_name, model_name]
    if extra:
        path_parts.append(extra)
    log_dir = os.path.join(*path_parts)

    print(f"[INFO] Created SummaryWriter(), saving to: {log_dir}")

    return SummaryWriter(log_dir=log_dir)
|
| |
|