From d2247c1e6130a940bf702f30289fdf71d00b53b3 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Fri, 13 Sep 2024 16:45:31 +0900 Subject: [PATCH] Normalize path returned by /userdata to always use / as separator (#4906) --- app/user_manager.py | 16 +++++---- .../prompt_server_test/user_manager_test.py | 34 +++++++++++++++++-- 2 files changed, 42 insertions(+), 8 deletions(-) diff --git a/app/user_manager.py b/app/user_manager.py index 260c383b..42bc496d 100644 --- a/app/user_manager.py +++ b/app/user_manager.py @@ -120,14 +120,14 @@ class UserManager(): async def listuserdata(request): directory = request.rel_url.query.get('dir', '') if not directory: - return web.Response(status=400) + return web.Response(status=400, text="Directory not provided") path = self.get_request_user_filepath(request, directory) if not path: - return web.Response(status=403) + return web.Response(status=403, text="Invalid directory") if not os.path.exists(path): - return web.Response(status=404) + return web.Response(status=404, text="Directory not found") recurse = request.rel_url.query.get('recurse', '').lower() == "true" full_info = request.rel_url.query.get('full_info', '').lower() == "true" @@ -143,17 +143,21 @@ class UserManager(): if full_info: results = [ { - 'path': os.path.relpath(x, path), + 'path': os.path.relpath(x, path).replace(os.sep, '/'), 'size': os.path.getsize(x), 'modified': os.path.getmtime(x) } for x in results if os.path.isfile(x) ] else: - results = [os.path.relpath(x, path) for x in results if os.path.isfile(x)] + results = [ + os.path.relpath(x, path).replace(os.sep, '/') + for x in results + if os.path.isfile(x) + ] split_path = request.rel_url.query.get('split', '').lower() == "true" if split_path and not full_info: - results = [[x] + x.split(os.sep) for x in results] + results = [[x] + x.split('/') for x in results] return web.json_response(results) diff --git a/tests-unit/prompt_server_test/user_manager_test.py b/tests-unit/prompt_server_test/user_manager_test.py index c71050a2..936c6bd2 100644 --- a/tests-unit/prompt_server_test/user_manager_test.py +++ b/tests-unit/prompt_server_test/user_manager_test.py @@ -2,6 +2,7 @@ import pytest import os from aiohttp import web from app.user_manager import UserManager +from unittest.mock import patch pytestmark = ( pytest.mark.asyncio @@ -53,7 +54,7 @@ async def test_listuserdata_recursive(aiohttp_client, app, tmp_path): client = await aiohttp_client(app) resp = await client.get("/userdata?dir=test_dir&recurse=true") assert resp.status == 200 - assert set(await resp.json()) == {"file1.txt", os.path.join("subdir", "file2.txt")} + assert set(await resp.json()) == {"file1.txt", "subdir/file2.txt"} async def test_listuserdata_full_info(aiohttp_client, app, tmp_path): @@ -80,7 +81,7 @@ async def test_listuserdata_split_path(aiohttp_client, app, tmp_path): resp = await client.get("/userdata?dir=test_dir&recurse=true&split=true") assert resp.status == 200 assert await resp.json() == [ - [os.path.join("subdir", "file1.txt"), "subdir", "file1.txt"] + ["subdir/file1.txt", "subdir", "file1.txt"] ] @@ -88,3 +89,32 @@ async def test_listuserdata_invalid_directory(aiohttp_client, app): client = await aiohttp_client(app) resp = await client.get("/userdata?dir=") assert resp.status == 400 + + +async def test_listuserdata_normalized_separator(aiohttp_client, app, tmp_path): + os_sep = "\\" + with patch("os.sep", os_sep): + with patch("os.path.sep", os_sep): + os.makedirs(tmp_path / "test_dir" / "subdir") + with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f: + f.write("test content") + + client = await aiohttp_client(app) + resp = await client.get("/userdata?dir=test_dir&recurse=true") + assert resp.status == 200 + result = await resp.json() + assert len(result) == 1 + assert "/" in result[0] # Ensure forward slash is used + assert "\\" not in result[0] # Ensure backslash is not present + assert result[0] == "subdir/file1.txt" + + # Test with full_info + resp = await client.get( + "/userdata?dir=test_dir&recurse=true&full_info=true" + ) + assert resp.status == 200 + result = await resp.json() + assert len(result) == 1 + assert "/" in result[0]["path"] # Ensure forward slash is used + assert "\\" not in result[0]["path"] # Ensure backslash is not present + assert result[0]["path"] == "subdir/file1.txt"