From ca4b8f30e0bf40cf58dcb3f3e6118832a60348c8 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Tue, 27 Aug 2024 02:07:25 -0400 Subject: [PATCH] Cleanup empty dir if frontend zip download failed (#4574) --- app/frontend_management.py | 31 ++++++++++++-------- tests-unit/app_test/frontend_manager_test.py | 30 +++++++++++++++++++ 2 files changed, 49 insertions(+), 12 deletions(-) diff --git a/app/frontend_management.py b/app/frontend_management.py index fb57b23f..9c832e46 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -8,7 +8,7 @@ import zipfile from dataclasses import dataclass from functools import cached_property from pathlib import Path -from typing import TypedDict +from typing import TypedDict, Optional import requests from typing_extensions import NotRequired @@ -132,12 +132,13 @@ class FrontendManager: return match_result.group(1), match_result.group(2), match_result.group(3) @classmethod - def init_frontend_unsafe(cls, version_string: str) -> str: + def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str: """ Initializes the frontend for the specified version. Args: version_string (str): The version string. + provider (FrontEndProvider, optional): The provider to use. Defaults to None. Returns: str: The path to the initialized frontend. @@ -150,7 +151,7 @@ class FrontendManager: return cls.DEFAULT_FRONTEND_PATH repo_owner, repo_name, version = cls.parse_version_string(version_string) - provider = FrontEndProvider(repo_owner, repo_name) + provider = provider or FrontEndProvider(repo_owner, repo_name) release = provider.get_release(version) semantic_version = release["tag_name"].lstrip("v") @@ -158,15 +159,21 @@ class FrontendManager: Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version ) if not os.path.exists(web_root): - os.makedirs(web_root, exist_ok=True) - logging.info( - "Downloading frontend(%s) version(%s) to (%s)", - provider.folder_name, - semantic_version, - web_root, - ) - logging.debug(release) - download_release_asset_zip(release, destination_path=web_root) + try: + os.makedirs(web_root, exist_ok=True) + logging.info( + "Downloading frontend(%s) version(%s) to (%s)", + provider.folder_name, + semantic_version, + web_root, + ) + logging.debug(release) + download_release_asset_zip(release, destination_path=web_root) + finally: + # Clean up the directory if it is empty, i.e. the download failed + if not os.listdir(web_root): + os.rmdir(web_root) + return web_root @classmethod diff --git a/tests-unit/app_test/frontend_manager_test.py b/tests-unit/app_test/frontend_manager_test.py index 637869cf..a8df5248 100644 --- a/tests-unit/app_test/frontend_manager_test.py +++ b/tests-unit/app_test/frontend_manager_test.py @@ -1,6 +1,7 @@ import argparse import pytest from requests.exceptions import HTTPError +from unittest.mock import patch from app.frontend_management import ( FrontendManager, @@ -83,6 +84,35 @@ def test_init_frontend_invalid_provider(): with pytest.raises(HTTPError): FrontendManager.init_frontend_unsafe(version_string) +@pytest.fixture +def mock_os_functions(): + with patch('app.frontend_management.os.makedirs') as mock_makedirs, \ + patch('app.frontend_management.os.listdir') as mock_listdir, \ + patch('app.frontend_management.os.rmdir') as mock_rmdir: + mock_listdir.return_value = [] # Simulate empty directory + yield mock_makedirs, mock_listdir, mock_rmdir + +@pytest.fixture +def mock_download(): + with patch('app.frontend_management.download_release_asset_zip') as mock: + mock.side_effect = Exception("Download failed") # Simulate download failure + yield mock + +def test_finally_block(mock_os_functions, mock_download, mock_provider): + # Arrange + mock_makedirs, mock_listdir, mock_rmdir = mock_os_functions + version_string = 'test-owner/test-repo@1.0.0' + + # Act & Assert + with pytest.raises(Exception): + FrontendManager.init_frontend_unsafe(version_string, mock_provider) + + # Assert + mock_makedirs.assert_called_once() + mock_download.assert_called_once() + mock_listdir.assert_called_once() + mock_rmdir.assert_called_once() + def test_parse_version_string(): version_string = "owner/repo@1.0.0"