From f50e47ae916672c822e68d8f489b95b67d9c3a37 Mon Sep 17 00:00:00 2001 From: Sunandhita B Date: Fri, 29 Nov 2024 13:37:42 +0530 Subject: [PATCH 01/11] Updated api tests to improve test coverage --- api/app/handlers/v2/bot.py | 1 + api/tests/test_bot.py | 57 +++++ api/tests/test_crud.py | 450 +++++++++++++++++++++++++++++++++++++ 3 files changed, 508 insertions(+) create mode 100644 api/tests/test_bot.py create mode 100644 api/tests/test_crud.py diff --git a/api/app/handlers/v2/bot.py b/api/app/handlers/v2/bot.py index da729124..284929d3 100644 --- a/api/app/handlers/v2/bot.py +++ b/api/app/handlers/v2/bot.py @@ -32,6 +32,7 @@ async def install(install_content: JBBotCode) -> Flow: fsm_code=install_content.code, requirements_txt=install_content.requirements, index_urls=install_content.index_urls, + version = bot.version ), ), ) diff --git a/api/tests/test_bot.py b/api/tests/test_bot.py new file mode 100644 index 00000000..24ec46c9 --- /dev/null +++ b/api/tests/test_bot.py @@ -0,0 +1,57 @@ +from unittest.mock import patch +import pytest +from lib.data_models import Flow, BotConfig, BotIntent, FlowIntent, Bot +from lib.models import JBBot +from app.handlers.v2.bot import list_bots, install +from app.jb_schema import JBBotCode + +@pytest.mark.asyncio +async def test_list_bots(): + mock_bot1 = JBBot(id="test_bot_1", name="Bot1", status="active") + mock_bot2 = JBBot(id="test_bot_2", name="Bot2", status="active") + + bot_list = [mock_bot1,mock_bot2] + with patch("app.handlers.v2.bot.get_bot_list", return_value = bot_list) as mock_get_bot_list: + result = await list_bots() + + assert result == bot_list + mock_get_bot_list.assert_awaited_once() + +@pytest.mark.asyncio +async def test_install(): + + mock_jbbot_code = JBBotCode( + name = "Bot1", + status = "active", + dsl = "test_dsl", + code = "test_code", + requirements = "codaio", + index_urls = ["index_url_1","index_url_2"], + version = "1.0.0", + ) + + mock_bot1 = JBBot(id="test_bot_1", + name=mock_jbbot_code.name, + status=mock_jbbot_code.status, + dsl = mock_jbbot_code.dsl, + code = mock_jbbot_code.code, + requirements = mock_jbbot_code.requirements, + index_urls = mock_jbbot_code.index_urls, + version = mock_jbbot_code.version) + + with patch("app.handlers.v2.bot.create_bot", return_value = mock_bot1) as mock_create_bot: + result = await install(mock_jbbot_code) + + assert isinstance(result,Flow) + assert result.source == "api" + assert result.intent == FlowIntent.BOT + assert isinstance(result.bot_config,BotConfig) + assert result.bot_config.bot_id == mock_bot1.id + assert isinstance(result.bot_config.bot,Bot) + assert result.bot_config.bot.name == mock_jbbot_code.name + assert result.bot_config.bot.fsm_code == mock_jbbot_code.code + assert result.bot_config.bot.requirements_txt == mock_jbbot_code.requirements + assert result.bot_config.bot.index_urls == mock_jbbot_code.index_urls + assert result.bot_config.bot.version == mock_bot1.version + + mock_create_bot.assert_awaited_once_with(mock_jbbot_code.model_dump()) diff --git a/api/tests/test_crud.py b/api/tests/test_crud.py new file mode 100644 index 00000000..613bb5a5 --- /dev/null +++ b/api/tests/test_crud.py @@ -0,0 +1,450 @@ +import pytest +from unittest import mock +from uuid import uuid4 +from lib.db_session_handler import DBSessionHandler +from lib.models import JBUser, JBTurn, JBChannel, JBBot +from app.crud import ( + create_user, + create_turn, + get_user_by_number, + get_channel_by_id, + get_bot_list, + get_channels_by_identifier, + update_bot, + update_channel, + update_channel_by_bot_id, +) + +class AsyncContextManagerMock: + def __init__(self, session_mock): + self.session_mock = session_mock + + async def __aenter__(self): + return self.session_mock + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + +class AsyncBeginMock: + async def __aenter__(self): + pass + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + +@pytest.mark.asyncio +async def test_create_user_success(): + channel_id = "channel123" + phone_number = "1234567890" + first_name = "John" + last_name = "Doe" + + mock_session = mock.Mock() + mock_session.commit = mock.AsyncMock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + user = await create_user(channel_id, phone_number, first_name, last_name) + + assert user is not None + assert user.first_name == first_name + assert user.last_name == last_name + assert user.identifier == phone_number + assert user.channel_id == channel_id + assert isinstance(user.id, str) + assert len(user.id) == 36 + + mock_session.commit.assert_awaited_once() + +@pytest.mark.asyncio +async def test_create_user_db_failure(): + channel_id = "channel123" + phone_number = "1234567890" + first_name = "John" + last_name = "Doe" + + with mock.patch.object(DBSessionHandler, 'get_async_session', side_effect=Exception("Database error")): + with pytest.raises(Exception): + await create_user(channel_id, phone_number, first_name, last_name) + +@pytest.mark.asyncio +async def test_create_turn_success(): + bot_id = "test_bot_id" + channel_id = "channel123" + user_id = "test_user_id" + + mock_session = mock.Mock() + mock_session.commit = mock.AsyncMock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + turn_id = await create_turn(bot_id, channel_id, user_id) + + assert turn_id is not None + assert isinstance(turn_id, str) + assert len(turn_id) == 36 + + mock_session.commit.assert_awaited_once() + +@pytest.mark.asyncio +async def test_create_turn_db_failure(): + + bot_id = "test_bot_id" + channel_id = "channel123" + user_id = "test_user_id" + + with mock.patch.object(DBSessionHandler, 'get_async_session', side_effect=Exception("Database error")): + with pytest.raises(Exception): + await create_turn(bot_id, channel_id, user_id) + +@pytest.mark.asyncio +async def test_get_user_by_number_success(): + phone_number = "1234567890" + channel_id = "channel123" + + mock_user = JBUser( + id=str(uuid4()), + channel_id=channel_id, + first_name="John", + last_name="Doe", + identifier=phone_number + ) + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + + mock_execute_result = mock.Mock() + mock_execute_result.scalars.return_value.first.return_value = mock_user + + mock_session.execute = mock.AsyncMock(return_value=mock_execute_result) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + + result = await get_user_by_number(phone_number, channel_id) + + assert result.id == mock_user.id + assert result.channel_id == mock_user.channel_id + assert result.first_name == mock_user.first_name + assert result.last_name == mock_user.last_name + assert result.identifier == mock_user.identifier + mock_session.execute.assert_awaited_once() + +@pytest.mark.asyncio +async def test_get_user_by_number_failure(): + phone_number = "1234567890" + channel_id = "channel123" + + with mock.patch.object(DBSessionHandler, 'get_async_session', side_effect=Exception("Database error")): + with pytest.raises(Exception): + await get_user_by_number(phone_number, channel_id) + +@pytest.mark.asyncio +async def test_get_channel_by_id_success(): + channel_id = "channel123" + + mock_channel = JBChannel( + id = channel_id, + bot_id = "test_bot_id", + status = "active", + name = "telegram", + type = "telegram", + key = "mfvghsikzhfcdfhjsrghehssliakzjfhsk" + ) + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + + mock_execute_result = mock.Mock() + mock_execute_result.scalars.return_value.first.return_value = mock_channel + + mock_session.execute = mock.AsyncMock(return_value=mock_execute_result) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + + result = await get_channel_by_id(channel_id) + + assert result.id == mock_channel.id + assert result.bot_id == mock_channel.bot_id + assert result.status == mock_channel.status + assert result.name == mock_channel.name + assert result.type == mock_channel.type + assert result.key == mock_channel.key + + mock_session.execute.assert_awaited_once() + +@pytest.mark.asyncio +async def test_get_channel_by_id_failure(): + channel_id = "channel123" + + with mock.patch.object(DBSessionHandler, 'get_async_session', side_effect=Exception("Database error")): + with pytest.raises(Exception): + await get_channel_by_id(channel_id) + +@pytest.mark.asyncio +async def test_get_bot_list_success(): + mock_bot1 = JBBot(id=1, name="Bot1", status="active") + mock_bot2 = JBBot(id=2, name="Bot2", status="active") + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + mock_execute_result = mock.Mock() + + mock_scalars = mock.Mock() + mock_scalars.unique.return_value = mock_scalars + mock_scalars.all.return_value = [mock_bot1, mock_bot2] + + mock_execute_result.scalars.return_value = mock_scalars + mock_session.execute = mock.AsyncMock(return_value=mock_execute_result) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + bot_list = await get_bot_list() + assert bot_list == [mock_bot1, mock_bot2] + mock_session.execute.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_bot_list_no_bots_found(): + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + mock_execute_result = mock.Mock() + + mock_scalars = mock.Mock() + mock_scalars.unique.return_value = mock_scalars + mock_scalars.all.return_value = [] + + mock_execute_result.scalars.return_value = mock_scalars + mock_session.execute = mock.AsyncMock(return_value=mock_execute_result) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + bot_list = await get_bot_list() + assert bot_list == [] + mock_session.execute.assert_awaited_once() + +@pytest.mark.asyncio +async def test_get_channels_by_identifier_success(): + identifier = "1234567890" + channel_type = "telegram" + + mock_channel1 = JBChannel(app_id="1234567890", type="telegram", status="inactive") + mock_channel2 = JBChannel(app_id="1234567890", type="telegram", status="active") + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + mock_execute_result = mock.Mock() + + mock_scalars = mock.Mock() + mock_scalars.unique.return_value = mock_scalars + mock_scalars.all.return_value = [mock_channel1, mock_channel2] + + mock_execute_result.scalars.return_value = mock_scalars + mock_session.execute = mock.AsyncMock(return_value=mock_execute_result) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + channels_list = await get_channels_by_identifier(identifier, channel_type) + assert channels_list == [mock_channel1, mock_channel2] + mock_session.execute.assert_awaited_once() + +@pytest.mark.asyncio +async def test_get_channels_by_identifier_no_channels_found(): + + identifier = "1234567890" + channel_type = "telegram" + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + mock_execute_result = mock.Mock() + + mock_scalars = mock.Mock() + mock_scalars.unique.return_value = mock_scalars + mock_scalars.all.return_value = [] + + mock_execute_result.scalars.return_value = mock_scalars + mock_session.execute = mock.AsyncMock(return_value=mock_execute_result) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + channels_list = await get_channels_by_identifier(identifier, channel_type) + assert channels_list == [] + mock_session.execute.assert_awaited_once() + +@pytest.mark.asyncio +async def test_update_bot_success(): + bot_id = "test_bot_id" + data = {"status":"active"} + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + mock_session.commit = mock.AsyncMock() + + mock_execute = mock.AsyncMock() + mock_execute.return_value.rowcount = 1 + mock_session.execute = mock_execute + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + + result = await update_bot(bot_id, data) + assert result is not None + assert result == bot_id + + mock_session.execute.assert_awaited_once() + mock_session.commit.assert_awaited_once() + +@pytest.mark.asyncio +async def test_update_bot_no_bot_found(): + bot_id = None + data = {"status":"active"} + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + mock_session.commit = mock.AsyncMock() + + mock_execute = mock.AsyncMock() + mock_execute.return_value.rowcount = 0 + mock_session.execute = mock_execute + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + result = await update_bot(bot_id, data) + assert result is None + + mock_session.execute.assert_awaited_once() + mock_session.commit.assert_awaited_once() + +@pytest.mark.asyncio +async def test_update_bot_error(): + + bot_id = "test_bot_id" + data = {"status":"deleted"} + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + mock_session.commit = mock.AsyncMock() + + mock_execute = mock.AsyncMock(side_effect=Exception("Database error")) + mock_session.execute = mock_execute + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + with pytest.raises(Exception, match="Database error"): + await update_bot(bot_id, data) + + mock_session.execute.assert_awaited_once() + mock_session.commit.assert_not_awaited() + +@pytest.mark.asyncio +async def test_update_channel_success(): + channel_id = "test_channel_id" + data = {"status": "active", "key": "ahjbdbhsbdrhiiuciuhrqtnjuifh"} + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + mock_session.commit = mock.AsyncMock() + + mock_execute = mock.AsyncMock() + mock_execute.return_value.rowcount = 1 + mock_session.execute = mock_execute + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + result = await update_channel(channel_id, data) + assert result is not None + assert result == channel_id + + mock_session.execute.assert_awaited_once() + mock_session.commit.assert_awaited_once() + +@pytest.mark.asyncio +async def test_update_channel_no_channel_found(): + channel_id = None + data = {"status": "active", "key": "ahjbdbhsbdrhiiuciuhrqtnjuifh"} + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + mock_session.commit = mock.AsyncMock() + + mock_execute = mock.AsyncMock() + mock_execute.return_value.rowcount = 0 + mock_session.execute = mock_execute + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + result = await update_channel(channel_id, data) + assert result is None + + mock_session.execute.assert_awaited_once() + mock_session.commit.assert_awaited_once() + +@pytest.mark.asyncio +async def test_update_channel_error(): + channel_id = "test_channel_id" + data = {"status": "active", "key": "ahjbdbhsbdrhiiuciuhrqtnjuifh"} + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + mock_session.commit = mock.AsyncMock() + + mock_execute = mock.AsyncMock(side_effect=Exception("Database error")) + mock_session.execute = mock_execute + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + with pytest.raises(Exception, match="Database error"): + await update_channel(channel_id, data) + + mock_session.execute.assert_awaited_once() + mock_session.commit.assert_not_awaited() + +@pytest.mark.asyncio +async def test_update_channel_by_bot_id_success(): + bot_id = "test_bot_id" + data = {"status": "deleted"} + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + mock_session.commit = mock.AsyncMock() + + mock_execute = mock.AsyncMock() + mock_execute.return_value.rowcount = 1 + mock_session.execute = mock_execute + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + result = await update_channel_by_bot_id(bot_id, data) + assert result is not None + assert result == bot_id + + mock_session.execute.assert_awaited_once() + mock_session.commit.assert_awaited_once() + +@pytest.mark.asyncio +async def test_update_channel_by_bot_id_no_channel_found(): + bot_id = None + data = {"status": "deleted"} + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + mock_session.commit = mock.AsyncMock() + + mock_execute = mock.AsyncMock() + mock_execute.return_value.rowcount = 0 + mock_session.execute = mock_execute + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + result = await update_channel_by_bot_id(bot_id, data) + assert result is None + + mock_session.execute.assert_awaited_once() + mock_session.commit.assert_awaited_once() + +@pytest.mark.asyncio +async def test_update_channel_by_bot_id_error(): + bot_id = "test_bot_id" + data = {"status": "deleted"} + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + mock_session.commit = mock.AsyncMock() + + mock_execute = mock.AsyncMock(side_effect=Exception("Database error")) + mock_session.execute = mock_execute + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + with pytest.raises(Exception, match="Database error"): + await update_channel_by_bot_id(bot_id, data) + + mock_session.execute.assert_awaited_once() + mock_session.commit.assert_not_awaited() \ No newline at end of file From 7b22cb5be08066a00de1ca8749fb5231ba69261b Mon Sep 17 00:00:00 2001 From: Ananya Agrawal Date: Mon, 9 Dec 2024 10:05:29 +0530 Subject: [PATCH 02/11] Added channel tests to improve coverage --- channel/tests/test_crud.py | 122 +++++++++++++++++++++++++++++++++ channel/tests/test_outgoing.py | 61 +++++++++++++++++ 2 files changed, 183 insertions(+) create mode 100644 channel/tests/test_crud.py diff --git a/channel/tests/test_crud.py b/channel/tests/test_crud.py new file mode 100644 index 00000000..37bfcaa2 --- /dev/null +++ b/channel/tests/test_crud.py @@ -0,0 +1,122 @@ +import pytest +from unittest import mock +from lib.db_session_handler import DBSessionHandler +from lib.models import JBUser, JBChannel, JBForm +from src.crud import ( + get_channel_by_turn_id, + get_form_parameters, + create_message, + get_user_by_turn_id, +) + +class AsyncContextManagerMock: + def __init__(self, session_mock): + self.session_mock = session_mock + + async def __aenter__(self): + return self.session_mock + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + +class AsyncBeginMock: + async def __aenter__(self): + pass + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + +@pytest.mark.asyncio +async def test_get_channel_by_turn_id(): + + turn_id = "test_turn_id" + + mock_channel = JBChannel( + app_id="test_number", + key="encrypted_credentials", + type="pinnacle_whatsapp", + url="https://api.pinnacle.com", + ) + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + + mock_execute_result = mock.Mock() + mock_execute_result.scalars.return_value.first.return_value = mock_channel + + mock_session.execute = mock.AsyncMock(return_value=mock_execute_result) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + + result = await get_channel_by_turn_id(turn_id) + + assert result.app_id == mock_channel.app_id + mock_session.execute.assert_awaited_once() + +@pytest.mark.asyncio +async def test_create_message_success(): + turn_id = "test_turn_id" + message_type = "text" + message = {"text": "Hi"} + is_user_sent = False + + mock_session = mock.Mock() + mock_session.commit = mock.AsyncMock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + msg_id = await create_message(turn_id, message_type, message, is_user_sent) + + assert msg_id is not None + +@pytest.mark.asyncio +async def test_get_user_by_turn_id(): + + turn_id = "test_turn_id" + + mock_user = JBUser( + id = turn_id, + ) + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + + mock_execute_result = mock.Mock() + mock_execute_result.scalars.return_value.first.return_value = mock_user + + mock_session.execute = mock.AsyncMock(return_value=mock_execute_result) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + + result = await get_user_by_turn_id(turn_id) + + assert isinstance(result, JBUser) + assert result.id == turn_id + mock_session.execute.assert_awaited_once() + +@pytest.mark.asyncio +async def test_get_form_parameters(): + + channel_id = "test_channel_id" + form_uid = "test_form_uid" + + mock_form = JBForm( + form_uid = form_uid, + channel_id = channel_id, + ) + + mock_form.parameters = {"Param": "Value"} + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + + mock_execute_result = mock.Mock() + mock_execute_result.scalars.return_value.first.return_value = mock_form + + mock_session.execute = mock.AsyncMock(return_value=mock_execute_result) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + + result = await get_form_parameters(channel_id, form_uid) + assert result == mock_form.parameters + mock_session.execute.assert_awaited_once() diff --git a/channel/tests/test_outgoing.py b/channel/tests/test_outgoing.py index 212b8a94..b663ecb3 100644 --- a/channel/tests/test_outgoing.py +++ b/channel/tests/test_outgoing.py @@ -152,3 +152,64 @@ async def test_send_message_to_user(message): channel=mock_channel, user=mock_user, message=message ) mock_create_message.assert_called_once() + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "message", + list(test_messages.values()), + ids=list(test_messages.keys()), +) +async def test_send_message_to_user_jb_user_not_found(message): + mock_channel = MagicMock( + app_id="test_number", + key="encrypted_credentials", + type="pinnacle_whatsapp", + url="https://api.pinnacle.com", + ) + + mock_get_user_by_turn_id = AsyncMock(return_value=None) + mock_get_channel_by_turn_id = AsyncMock(return_value=mock_channel) + mock_send_message = MagicMock() + mock_create_message = AsyncMock() + + turn_id = "test_turn_id" + + with patch("src.handlers.outgoing.get_user_by_turn_id", mock_get_user_by_turn_id): + with patch( + "src.handlers.outgoing.get_channel_by_turn_id", mock_get_channel_by_turn_id + ): + with patch( + "lib.channel_handler.pinnacle_whatsapp_handler.PinnacleWhatsappHandler.send_message", + mock_send_message, + ): + with patch("src.handlers.outgoing.create_message", mock_create_message): + await send_message_to_user(turn_id=turn_id, message=message) + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "message", + list(test_messages.values()), + ids=list(test_messages.keys()), +) +async def test_send_message_to_user_jb_channel_not_found(message): + mock_user = MagicMock( + identifier="1234567890", + ) + + mock_get_user_by_turn_id = AsyncMock(return_value=mock_user) + mock_get_channel_by_turn_id = AsyncMock(return_value=None) + mock_send_message = MagicMock() + mock_create_message = AsyncMock() + + turn_id = "test_turn_id" + + with patch("src.handlers.outgoing.get_user_by_turn_id", mock_get_user_by_turn_id): + with patch( + "src.handlers.outgoing.get_channel_by_turn_id", mock_get_channel_by_turn_id + ): + with patch( + "lib.channel_handler.pinnacle_whatsapp_handler.PinnacleWhatsappHandler.send_message", + mock_send_message, + ): + with patch("src.handlers.outgoing.create_message", mock_create_message): + await send_message_to_user(turn_id=turn_id, message=message) From 07f2585ac86052ef64fb0fa4a09f1e81e69a7ba4 Mon Sep 17 00:00:00 2001 From: Sunandhita B Date: Mon, 9 Dec 2024 12:36:25 +0530 Subject: [PATCH 03/11] Added api tests to improve test coverage --- api/tests/test_bot.py | 195 +++++++++++++++++++- api/tests/test_channel.py | 234 ++++++++++++++++++++++++ api/tests/test_crud.py | 265 ++++++++++++++++++++++++++- api/tests/test_extensions.py | 39 ++++ api/tests/test_routers_bot.py | 284 +++++++++++++++++++++++++++++ api/tests/test_routers_callback.py | 111 +++++++++++ api/tests/test_routers_channel.py | 192 +++++++++++++++++++ 7 files changed, 1314 insertions(+), 6 deletions(-) create mode 100644 api/tests/test_channel.py create mode 100644 api/tests/test_extensions.py create mode 100644 api/tests/test_routers_bot.py create mode 100644 api/tests/test_routers_callback.py create mode 100644 api/tests/test_routers_channel.py diff --git a/api/tests/test_bot.py b/api/tests/test_bot.py index 24ec46c9..8a663472 100644 --- a/api/tests/test_bot.py +++ b/api/tests/test_bot.py @@ -1,9 +1,9 @@ from unittest.mock import patch import pytest -from lib.data_models import Flow, BotConfig, BotIntent, FlowIntent, Bot -from lib.models import JBBot -from app.handlers.v2.bot import list_bots, install -from app.jb_schema import JBBotCode +from lib.data_models import Flow, BotConfig, FlowIntent, Bot +from lib.models import JBBot, JBChannel +from app.handlers.v2.bot import list_bots, install, add_credentials, add_channel, delete +from app.jb_schema import JBBotCode, JBChannelContent @pytest.mark.asyncio async def test_list_bots(): @@ -55,3 +55,190 @@ async def test_install(): assert result.bot_config.bot.version == mock_bot1.version mock_create_bot.assert_awaited_once_with(mock_jbbot_code.model_dump()) + +@pytest.mark.asyncio +async def test_add_credentials_success(): + bot_id = "test_bot_id" + credentials = {"key":"test_key"} + + mock_bot = JBBot(id="test_bot_id", name="mock_bot", status="active") + + with patch("app.handlers.v2.bot.get_bot_by_id", return_value = mock_bot) as mock_get_bot_by_id, \ + patch("app.handlers.v2.bot.update_bot", return_value = bot_id) as mock_update_bot: + + result = await add_credentials(bot_id,credentials) + + assert len(result) == 1 + assert 'status' in result + assert result.get('status') == 'success' + + mock_get_bot_by_id.assert_awaited_once_with(bot_id) + mock_update_bot.assert_awaited_once() + +@pytest.mark.asyncio +async def test_add_credentials_failure(): + bot_id = "test_bot_id" + credentials = {"key":"test_key"} + + with patch("app.handlers.v2.bot.get_bot_by_id", return_value = None) as mock_get_bot_by_id, \ + patch("app.handlers.v2.bot.update_bot", return_value = bot_id) as mock_update_bot: + + result = await add_credentials(bot_id,credentials) + + assert len(result) == 2 + assert 'status' in result + assert result.get('status') == 'error' + assert 'message' in result + assert result.get('message') == 'Bot not found' + + mock_get_bot_by_id.assert_awaited_once_with(bot_id) + mock_update_bot.assert_not_awaited() + +@pytest.mark.asyncio +async def test_add_channel_when_bot_with_bot_id_does_not_exist(): + + bot_id = "test_bot_id" + + channel_content = JBChannelContent( + name = "test_channel_content", + type = "test_type", + url = "test_url", + app_id = "12345678", + key = "test_key", + status = "active" + ) + + with patch("app.handlers.v2.bot.get_bot_by_id", return_value = None) as mock_get_bot_by_id: + result = await add_channel(bot_id,channel_content) + + assert len(result) == 2 + assert 'status' in result + assert result.get('status') == 'error' + assert 'message' in result + assert result.get('message') == "Bot not found" + + mock_get_bot_by_id.assert_awaited_once_with(bot_id) + +@pytest.mark.asyncio +async def test_add_channel_when_channel_already_in_use_by_bot(): + + bot_id = "test_bot_id" + + channel_content = JBChannelContent( + name = "test_channel_content", + type = "test_type", + url = "test_url", + app_id = "12345678", + key = "test_key", + status = "active" + ) + + mock_bot = JBBot(id="test_bot_id", name="Bot1", status="active") + + mock_existing_channel = JBChannel( + id = "test_channel_id", + bot_id = "test_bot_id", + status = "active", + name = "test_channel", + type = "test_type", + key = "test_key", + app_id = "12345678", + url = "test_url" + ) + + with patch("app.handlers.v2.bot.get_bot_by_id", return_value = mock_bot) as mock_get_bot_by_id, \ + patch("app.handlers.v2.bot.get_active_channel_by_identifier", return_value = mock_existing_channel) as mock_get_active_channel_by_identifier: + + result = await add_channel(bot_id,channel_content) + + assert len(result) == 2 + assert 'status' in result + assert result.get('status') == 'error' + assert 'message' in result + + mock_get_bot_by_id.assert_awaited_once_with(bot_id) + mock_get_active_channel_by_identifier.assert_awaited_once_with(identifier = channel_content.app_id, + channel_type = channel_content.type) + +@pytest.mark.asyncio +async def test_add_channel_when_channel_creation_is_success(): + + bot_id = "test_bot_id" + + channel_content = JBChannelContent( + name = "test_channel_content", + type = "test_type", + url = "test_url", + app_id = "12345678", + key = "test_key", + status = "active" + ) + + mock_bot = JBBot(id="test_bot_id", name="Bot1", status="active") + + mock_channel = JBChannel( + id = "test_channel_id", + bot_id = "test_bot_id", + status = "active", + name = "test_channel", + type = "test_type", + key = "test_key", + app_id = "12345678", + url = "test_url" + ) + + with patch("app.handlers.v2.bot.get_bot_by_id", return_value = mock_bot) as mock_get_bot_by_id, \ + patch("app.handlers.v2.bot.get_active_channel_by_identifier", return_value = None) as mock_get_active_channel_by_identifier, \ + patch("app.handlers.v2.bot.create_channel",return_value = mock_channel) as mock_create_channel: + + result = await add_channel(bot_id,channel_content) + + assert len(result) == 1 + assert 'status' in result + assert result.get('status') == 'success' + + mock_get_bot_by_id.assert_awaited_once_with(bot_id) + mock_get_active_channel_by_identifier.assert_awaited_once_with(identifier = channel_content.app_id, + channel_type = channel_content.type) + mock_create_channel.assert_awaited_once_with(bot_id, channel_content.model_dump()) + +@pytest.mark.asyncio +async def test_delete_success(): + + bot_id = "test_bot_id" + + mock_bot = JBBot(id="test_bot_id", name="Bot1", status="active") + + bot_data = {"status": "deleted"} + channel_data = {"status": "deleted"} + + with patch("app.handlers.v2.bot.get_bot_by_id", return_value = mock_bot) as mock_get_bot_by_id, \ + patch("app.handlers.v2.bot.update_bot", return_value = bot_id) as mock_update_bot,\ + patch("app.handlers.v2.bot.update_channel_by_bot_id", return_value = bot_id) as mock_update_channel_by_bot_id: + + result = await delete(bot_id) + + assert len(result) == 1 + assert 'status' in result + assert result.get('status') == 'success' + + mock_get_bot_by_id.assert_awaited_once_with(bot_id) + mock_update_bot.assert_awaited_once_with(bot_id, bot_data) + mock_update_channel_by_bot_id.assert_awaited_once_with(bot_id, channel_data) + +@pytest.mark.asyncio +async def test_delete_failure(): + + bot_id = "test_bot_id" + + with patch("app.handlers.v2.bot.get_bot_by_id", return_value = None) as mock_get_bot_by_id: + + result = await delete(bot_id) + + assert len(result) == 2 + assert 'status' in result + assert result.get('status') == 'error' + assert 'message' in result + assert result.get('message') == 'Bot not found' + + mock_get_bot_by_id.assert_awaited_once_with(bot_id) \ No newline at end of file diff --git a/api/tests/test_channel.py b/api/tests/test_channel.py new file mode 100644 index 00000000..c9e2eab1 --- /dev/null +++ b/api/tests/test_channel.py @@ -0,0 +1,234 @@ +from unittest.mock import patch +import pytest +from app.handlers.v2.channel import list_available_channels, update, activate, deactivate, delete +from lib.channel_handler import channel_map +from lib.models import JBChannel + +@pytest.mark.asyncio +async def test_list_available_channels(): + + result = await list_available_channels() + + assert isinstance(result, list) + assert result == list(channel_map.keys()) + +@pytest.mark.asyncio +async def test_update_failure_when_channel_not_found(): + + channel_id = "test_channel_id" + channel_data = {"key":"test_key", "app_id":"12345678", "name":"test_channel", "type":"test_channel_type", "url":"test_url"} + + with patch("app.handlers.v2.channel.get_channel_by_id", return_value = None) as mock_get_channel_by_id: + result = await update(channel_id, channel_data) + + assert len(result) == 2 + assert 'status' in result + assert result.get('status') == 'error' + assert 'message' in result + assert result.get('message') == 'Channel not found' + + mock_get_channel_by_id.assert_awaited_once_with(channel_id) + +@pytest.mark.asyncio +async def test_update_failure_when_given_channel_type_not_supported_by_this_manager(): + + channel_id = "test_channel_id" + channel_data = {"key":"test_key", "app_id":"12345678", "name":"test_channel", "type":"test_channel_type", "url":"test_url"} + + mock_channel_object = JBChannel( + id = "test_channel_id", + bot_id = "test_bot_id", + status = "active", + name = "test_channel", + type = "test_channel_type", + key = "test_key", + app_id = "12345678", + url = "test_url", + ) + + with patch("app.handlers.v2.channel.get_channel_by_id", return_value = mock_channel_object) as mock_get_channel_by_id: + result = await update(channel_id, channel_data) + + assert len(result) == 2 + assert 'status' in result + assert result.get('status') == 'error' + assert 'message' in result + assert result.get('message') == 'Channel not supported by this manager' + + mock_get_channel_by_id.assert_awaited_once_with(channel_id) + +@pytest.mark.asyncio +async def test_update_success(): + + channel_id = "test_channel_id" + channel_data = {"key":"test_key", "app_id":"12345678", "name":"telegram", "type":"telegram", "url":"test_url"} + + mock_channel_object = JBChannel( + id = "test_channel_id", + bot_id = "test_bot_id", + status = "active", + name = "telegram", + type = "telegram", + key = "test_key", + app_id = "12345678", + url = "test_url", + ) + + with patch("app.handlers.v2.channel.get_channel_by_id", return_value = mock_channel_object) as mock_get_channel_by_id, \ + patch("app.handlers.v2.channel.update_channel", return_value = channel_id) as mock_update_channel: + + result = await update(channel_id, channel_data) + + assert len(result) == 1 + assert 'status' in result + assert result.get('status') == 'success' + + mock_get_channel_by_id.assert_awaited_once_with(channel_id) + mock_update_channel.assert_awaited_once() + +@pytest.mark.asyncio +async def test_activate_failure_when_channel_not_found(): + + channel_id = "test_channel_id" + + with patch("app.handlers.v2.channel.get_channel_by_id", return_value = None) as mock_get_channel_by_id: + result = await activate(channel_id) + + assert len(result) == 2 + assert 'status' in result + assert result.get('status') == 'error' + assert 'message' in result + assert result.get('message') == 'Channel not found' + + mock_get_channel_by_id.assert_awaited_once_with(channel_id) + +@pytest.mark.asyncio +async def test_activate_success(): + + channel_id = "test_channel_id" + channel_data = {"status": "active"} + + mock_channel_object = JBChannel( + id = "test_channel_id", + bot_id = "test_bot_id", + status = "active", + name = "telegram", + type = "telegram", + key = "test_key", + app_id = "12345678", + url = "test_url", + ) + + with patch("app.handlers.v2.channel.get_channel_by_id", return_value = mock_channel_object) as mock_get_channel_by_id, \ + patch("app.handlers.v2.channel.update_channel", return_value = channel_id) as mock_update_channel: + + result = await activate(channel_id) + + assert len(result) == 1 + assert 'status' in result + assert result.get('status') == 'success' + + mock_get_channel_by_id.assert_awaited_once_with(channel_id) + mock_update_channel.assert_awaited_once_with(channel_id, channel_data) + +@pytest.mark.asyncio +async def test_deactivate_failure_when_channel_not_found(): + + channel_id = "test_channel_id" + + with patch("app.handlers.v2.channel.get_channel_by_id", return_value = None) as mock_get_channel_by_id: + result = await deactivate(channel_id) + + assert len(result) == 2 + assert 'status' in result + assert result.get('status') == 'error' + assert 'message' in result + assert result.get('message') == 'Channel not found' + + mock_get_channel_by_id.assert_awaited_once_with(channel_id) + +@pytest.mark.asyncio +async def test_deactivate_success(): + + channel_id = "test_channel_id" + channel_data = {"status": "inactive"} + + mock_channel_object = JBChannel( + id = "test_channel_id", + bot_id = "test_bot_id", + status = "active", + name = "telegram", + type = "telegram", + key = "test_key", + app_id = "12345678", + url = "test_url", + ) + + with patch("app.handlers.v2.channel.get_channel_by_id", return_value = mock_channel_object) as mock_get_channel_by_id, \ + patch("app.handlers.v2.channel.update_channel", return_value = channel_id) as mock_update_channel: + + result = await deactivate(channel_id) + + assert len(result) == 1 + assert 'status' in result + assert result.get('status') == 'success' + + mock_get_channel_by_id.assert_awaited_once_with(channel_id) + mock_update_channel.assert_awaited_once_with(channel_id, channel_data) + +@pytest.mark.asyncio +async def test_delete_failure_when_channel_not_found(): + + channel_id = "test_channel_id" + + with patch("app.handlers.v2.channel.get_channel_by_id", return_value = None) as mock_get_channel_by_id: + result = await delete(channel_id) + + assert len(result) == 2 + assert 'status' in result + assert result.get('status') == 'error' + assert 'message' in result + assert result.get('message') == 'Channel not found' + + mock_get_channel_by_id.assert_awaited_once_with(channel_id) + + + + + + + + + + + + + +@pytest.mark.asyncio +async def test_delete_success(): + + channel_id = "test_channel_id" + channel_data = {"status": "deleted"} + + mock_channel_object = JBChannel( + id = "test_channel_id", + bot_id = "test_bot_id", + status = "active", + name = "telegram", + type = "telegram", + key = "test_key", + app_id = "12345678", + url = "test_url", + ) + + with patch("app.handlers.v2.channel.get_channel_by_id", return_value = mock_channel_object) as mock_get_channel_by_id, \ + patch("app.handlers.v2.channel.update_channel", return_value = channel_id) as mock_update_channel: + + result = await delete(channel_id) + + assert len(result) == 1 + assert 'status' in result + assert result.get('status') == 'success' + + mock_get_channel_by_id.assert_awaited_once_with(channel_id) + mock_update_channel.assert_awaited_once_with(channel_id, channel_data) \ No newline at end of file diff --git a/api/tests/test_crud.py b/api/tests/test_crud.py index 613bb5a5..f899bde9 100644 --- a/api/tests/test_crud.py +++ b/api/tests/test_crud.py @@ -2,7 +2,7 @@ from unittest import mock from uuid import uuid4 from lib.db_session_handler import DBSessionHandler -from lib.models import JBUser, JBTurn, JBChannel, JBBot +from lib.models import JBUser, JBChannel, JBBot, JBWebhookReference, JBSession from app.crud import ( create_user, create_turn, @@ -13,6 +13,12 @@ update_bot, update_channel, update_channel_by_bot_id, + get_bot_by_id, + get_plugin_reference, + get_bot_chat_sessions, + create_bot, + create_channel, + get_active_channel_by_identifier ) class AsyncContextManagerMock: @@ -447,4 +453,259 @@ async def test_update_channel_by_bot_id_error(): await update_channel_by_bot_id(bot_id, data) mock_session.execute.assert_awaited_once() - mock_session.commit.assert_not_awaited() \ No newline at end of file + mock_session.commit.assert_not_awaited() + +@pytest.mark.asyncio +async def test_get_bot_by_id_success(): + bot_id = "test_bot_id" + + mock_bot = JBBot( + id = "test_bot_id", + name = "My_Bot", + status = "active", + dsl = "test_dsl", + code = "test_code", + requirements = "codaio", + index_urls = ["index-url1","index_url2"], + required_credentials = ["OPEN_API_KEY", "CODAIO"], + version = "1.0.0" + ) + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + + mock_execute_result = mock.Mock() + mock_execute_result.scalars.return_value.unique.return_value.first.return_value = mock_bot + + mock_session.execute = mock.AsyncMock(return_value=mock_execute_result) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + + result = await get_bot_by_id(bot_id) + + assert isinstance(result, JBBot) + assert result.id == mock_bot.id + assert result.name == mock_bot.name + assert result.status == mock_bot.status + assert result.dsl == mock_bot.dsl + assert result.code == mock_bot.code + assert result.requirements == mock_bot.requirements + assert result.index_urls == mock_bot.index_urls + assert result.required_credentials == mock_bot.required_credentials + assert result.version == mock_bot.version + + mock_session.execute.assert_awaited_once() + +@pytest.mark.asyncio +async def test_get_bot_by_id_failure(): + bot_id = "test_bot_id" + + with mock.patch.object(DBSessionHandler, 'get_async_session', side_effect=Exception("Database error")): + with pytest.raises(Exception): + await get_bot_by_id(bot_id) + +@pytest.mark.asyncio +async def test_get_plugin_reference_success(): + plugin_uuid = "test_id" + + mock_webhook_reference = JBWebhookReference( + id = "test_id", + turn_id = "test_turn_id" + ) + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + + mock_execute_result = mock.Mock() + mock_execute_result.scalars.return_value.first.return_value = mock_webhook_reference + + mock_session.execute = mock.AsyncMock(return_value=mock_execute_result) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + + result = await get_plugin_reference(plugin_uuid) + + assert isinstance(result, JBWebhookReference) + assert result.id == mock_webhook_reference.id + assert result.turn_id == mock_webhook_reference.turn_id + + mock_session.execute.assert_awaited_once() + +@pytest.mark.asyncio +async def test_get_plugin_reference_failure(): + plugin_uuid = "test_id" + + with mock.patch.object(DBSessionHandler, 'get_async_session', side_effect=Exception("Database error")): + with pytest.raises(Exception): + await get_plugin_reference(plugin_uuid) + +@pytest.mark.asyncio +async def test_get_bot_chat_sessions_success(): + + bot_id = "test_bot_id" + session_id = "test_session_id" + + mock_chat_session1 = JBSession(id="test_session_id", user_id="test_user_id1", channel_id="test_channel_id1") + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + mock_execute_result = mock.Mock() + + mock_execute_result.unique.return_value.scalars.return_value.all.return_value = [mock_chat_session1] + + mock_session.execute = mock.AsyncMock(return_value=mock_execute_result) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + result_chat_sessions = await get_bot_chat_sessions (bot_id, session_id) + + assert len(result_chat_sessions) == 1 + assert result_chat_sessions[0].id == mock_chat_session1.id + assert result_chat_sessions[0].user_id == mock_chat_session1.user_id + assert result_chat_sessions[0].channel_id == mock_chat_session1.channel_id + + mock_session.execute.assert_awaited_once() + +@pytest.mark.asyncio +async def test_get_bot_chat_sessions_failure_no_chat_sessions_found(): + + bot_id = "test_bot_id" + session_id = "test_session_id" + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + mock_execute_result = mock.Mock() + + mock_execute_result.unique.return_value.scalars.return_value.all.return_value = [] + + mock_session.execute = mock.AsyncMock(return_value=mock_execute_result) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + result_chat_sessions = await get_bot_chat_sessions (bot_id, session_id) + assert result_chat_sessions == [] + mock_session.execute.assert_awaited_once() + +@pytest.mark.asyncio +async def test_create_bot_success(): + + data = {'name': 'Bot1', 'status': 'active', 'dsl': 'test_dsl', 'code': 'test_code', 'requirements': 'codaio', + 'index_urls': ['index_url_1', 'index_url_2'], 'required_credentials':['OPEN_API_KEY'], 'version': '1.0.0'} + + mock_session = mock.Mock() + mock_session.commit = mock.AsyncMock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + bot = await create_bot(data) + + assert bot is not None + assert isinstance(bot, JBBot) + assert isinstance(bot.id,str) + assert len(bot.id) == 36 + assert bot.name == data.get('name') + assert bot.dsl == data.get('dsl') + assert bot.code == data.get('code') + assert bot.requirements == data.get('requirements') + assert bot.index_urls == data.get('index_urls') + assert bot.required_credentials == data.get('required_credentials') + assert bot.version == data.get('version') + + mock_session.commit.assert_awaited_once() + +@pytest.mark.asyncio +async def test_create_bot_failure(): + + data = {'name': 'Bot1', 'status': 'active', 'dsl': 'test_dsl', 'code': 'test_code', 'requirements': 'codaio', + 'index_urls': ['index_url_1', 'index_url_2'], 'required_credentials':['OPEN_API_KEY'], 'version': '1.0.0'} + + with mock.patch.object(DBSessionHandler, 'get_async_session', side_effect=Exception("Database error")): + with pytest.raises(Exception): + await create_bot(data) + +@pytest.mark.asyncio +async def test_create_channel_success(): + + bot_id = "test_bot_id" + data = {'name': 'telegram', 'type': 'telegram', 'key': 'test_key', 'app_id': 'test_app_id', 'url': 'test_url','status': 'active'} + + mock_session = mock.Mock() + mock_session.commit = mock.AsyncMock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + channel = await create_channel(bot_id, data) + + assert channel is not None + assert isinstance(channel, JBChannel) + assert isinstance(channel.id,str) + assert len(channel.id) == 36 + assert channel.bot_id == bot_id + assert channel.name == data.get('name') + assert channel.type == data.get('type') + assert channel.key == data.get('key') + assert channel.app_id == data.get('app_id') + assert channel.url == data.get('url') + assert channel.status == data.get('status') + + mock_session.commit.assert_awaited_once() + +@pytest.mark.asyncio +async def test_create_channel_failure(): + + bot_id = "test_bot_id" + data = {'name': 'telegram', 'type': 'telegram', 'key': 'test_key', 'app_id': 'test_app_id', 'url': 'test_url','status': 'active'} + + with mock.patch.object(DBSessionHandler, 'get_async_session', side_effect=Exception("Database error")): + with pytest.raises(Exception): + await create_channel(bot_id, data) + +@pytest.mark.asyncio +async def test_get_active_channel_by_identifier_success(): + + identifier = "test_app_id" + channel_type = "telegram" + + mock_channel = JBChannel( + id = "test_channel_id", + bot_id = "test_bot_id" , + status = "active", + name = "telegram", + type = "telegram", + key = "test_key", + app_id = "test_app_id", + url = "test_url" + ) + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + + mock_execute_result = mock.Mock() + mock_execute_result.scalars.return_value.unique.return_value.first.return_value = mock_channel + + mock_session.execute = mock.AsyncMock(return_value=mock_execute_result) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + + result = await get_active_channel_by_identifier(identifier, channel_type) + + assert result is not None + assert isinstance(result, JBChannel) + assert result.id == mock_channel.id + assert result.bot_id == mock_channel.bot_id + assert result.status == mock_channel.status + assert result.name == mock_channel.name + assert result.type == mock_channel.type + assert result.key == mock_channel.key + assert result.app_id == mock_channel.app_id + assert result.url == mock_channel.url + + mock_session.execute.assert_awaited_once() + +@pytest.mark.asyncio +async def test_get_active_channel_by_identifier_failure(): + + identifier = "test_app_id" + channel_type = "telegram" + + with mock.patch.object(DBSessionHandler, 'get_async_session', side_effect=Exception("Database error")): + with pytest.raises(Exception): + await get_active_channel_by_identifier(identifier, channel_type) \ No newline at end of file diff --git a/api/tests/test_extensions.py b/api/tests/test_extensions.py new file mode 100644 index 00000000..8181a158 --- /dev/null +++ b/api/tests/test_extensions.py @@ -0,0 +1,39 @@ +# import pytest +# from unittest.mock import patch, MagicMock +# from lib.db_session_handler import DBSessionHandler +# from lib.data_models import Flow +# from lib.kafka_utils import KafkaProducer +# from app.extensions import produce_message +# import os + +# # channel_input = Channel( +# # source="api", +# # turn_id=turn_id, +# # intent=ChannelIntent.CHANNEL_IN, +# # bot_input=RestBotInput( +# # channel_name=chosen_channel.get_channel_name(), +# # headers=headers, +# # data=message_data, +# # query_params=query_params, +# # ), +# # ) +# @pytest.mark.asyncio +# @patch('api.app.extensions.KafkaProducer.from_env_vars') +# @patch.dict(os.environ, {'KAFKA_FLOW_TOPIC': 'flow', 'KAFKA_CHANNEL_TOPIC': 'channel', 'KAFKA_INDEXER_TOPIC': 'indexer'}) +# async def test_produce_message_for_flow_instance(mock_kafka_producer): + +# flow_input = Flow( +# source = "Api", +# intent = "User input" +# ) +# mock_producer = MagicMock() +# mock_kafka_producer.return_value = mock_producer + +# flow_topic = "flow_topic" + +# produce_message(flow_input) + +# mock_producer.send_message.assert_called_once_with( +# topic=flow_topic, +# value=flow_input.model_dump_json(exclude_none=True) +# ) diff --git a/api/tests/test_routers_bot.py b/api/tests/test_routers_bot.py new file mode 100644 index 00000000..62084352 --- /dev/null +++ b/api/tests/test_routers_bot.py @@ -0,0 +1,284 @@ +from fastapi import HTTPException, Request +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from lib.data_models.flow import Bot, BotConfig, BotIntent, Flow, FlowIntent +from lib.models import JBBot +from app.jb_schema import JBBotCode, JBChannelContent + +mock_extension = MagicMock() + +@pytest.mark.asyncio +async def test_get_all_bots(): + + mock_extension.reset_mock() + + mock_bot1 = JBBot(id = "test_bot_id1", name = "test_bot1", status = "active") + mock_bot2 = JBBot(id = "test_bot_id2", name = "test_bot2", status = "inactive") + + mock_bot_list = [mock_bot1, mock_bot2] + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v2.bot import get_all_bots + + with patch("app.routers.v2.bot.list_bots", return_value = mock_bot_list) as mock_list_bots: + + result = await get_all_bots() + + assert len(result) == len(mock_bot_list) + for i in range(len(result)): + assert result[i].id == mock_bot_list[i].id + assert result[i].name == mock_bot_list[i].name + assert result[i].status == mock_bot_list[i].status + + mock_list_bots.assert_awaited_once() + +@pytest.mark.asyncio +async def test_install_bot_success(): + + mock_extension.reset_mock() + + mock_jbbot_code = JBBotCode( + name = "Bot1", + status = "active", + dsl = "test_dsl", + code = "test_code", + requirements = "codaio", + index_urls = ["index_url_1","index_url_2"], + version = "1.0.0", + ) + + mock_flow_input = Flow( + source="api", + intent=FlowIntent.BOT, + bot_config=BotConfig( + bot_id="test_bot_id", + intent=BotIntent.INSTALL, + bot=Bot( + name=mock_jbbot_code.name, + fsm_code=mock_jbbot_code.code, + requirements_txt=mock_jbbot_code.requirements, + index_urls=mock_jbbot_code.index_urls, + version = "1.0.0" + ), + ) + ) + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v2.bot import install_bot + + with patch("app.routers.v2.bot.install", return_value = mock_flow_input) as mock_install: + + result = await install_bot(mock_jbbot_code) + + assert len(result) == 1 + assert 'status' in result + assert result.get('status') == 'success' + + mock_install.assert_awaited_once_with(mock_jbbot_code) + +@pytest.mark.asyncio +async def test_delete_bot_success(): + + mock_extension.reset_mock() + + bot_id = "test_bot_id" + mock_delete_response = {"status": "success"} + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v2.bot import delete_bot + + with patch("app.routers.v2.bot.delete", return_value = mock_delete_response) as mock_delete: + + result = await delete_bot(bot_id) + + assert len(result) == 1 + assert 'status' in result + assert result.get('status') == 'success' + + mock_delete.assert_awaited_once_with(bot_id) + +@pytest.mark.asyncio +async def test_delete_bot_failure(): + + mock_extension.reset_mock() + + bot_id = "test_bot_id" + mock_delete_response = {"status": "error", "message": "Bot not found"} + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v2.bot import delete_bot + + with patch("app.routers.v2.bot.delete", return_value = mock_delete_response) as mock_delete: + with pytest.raises(HTTPException) as exception_info: + await delete_bot(bot_id) + + assert exception_info.value.status_code == 404 + assert exception_info.value.detail == mock_delete_response["message"] + + mock_delete.assert_awaited_once_with(bot_id) + +@pytest.mark.asyncio +async def test_add_bot_credentials_success(): + + mock_extension.reset_mock() + + bot_id = "test_bot_id" + + request = { + "user": "test_bot_user", + "credentials" : {"key" : "test_key"} + } + + mock_request = AsyncMock(Request) + mock_request.json.return_value = request + + mock_add_credentials_response = {"status": "success"} + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v2.bot import add_bot_credentials + + with patch("app.routers.v2.bot.add_credentials", return_value = mock_add_credentials_response) as mock_add_credentials: + + result = await add_bot_credentials(bot_id, mock_request) + + assert len(result) == 1 + assert 'status' in result + assert result.get('status') == 'success' + + mock_add_credentials.assert_awaited_once() + +@pytest.mark.asyncio +async def test_add_bot_credentials_failure_when_no_credentials_have_been_provided_in_request(): + + mock_extension.reset_mock() + + bot_id = "test_bot_id" + request = {"user": "test_bot_user"} + + mock_request = AsyncMock(Request) + mock_request.json.return_value = request + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v2.bot import add_bot_credentials + + with pytest.raises(HTTPException) as exception_info: + + await add_bot_credentials(bot_id, mock_request) + + assert exception_info.value.status_code == 400 + assert exception_info.value.detail == "No credentials provided" + +@pytest.mark.asyncio +async def test_add_bot_credentials_failure_when_no_bot_found_with_given_bot_id(): + + mock_extension.reset_mock() + + bot_id = None + request = { + "user": "test_bot_user", + "credentials" : {"key" : "test_key"} + } + + mock_add_credentials_response = {"status": "error", "message": "Bot not found"} + + mock_request = AsyncMock(Request) + mock_request.json.return_value = request + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v2.bot import add_bot_credentials + with patch("app.routers.v2.bot.add_credentials", return_value = mock_add_credentials_response) as mock_add_credentials: + with pytest.raises(HTTPException) as exception_info: + + await add_bot_credentials(bot_id, mock_request) + + assert exception_info.value.status_code == 404 + assert exception_info.value.detail == mock_add_credentials_response["message"] + + mock_add_credentials.assert_awaited_once() + +@pytest.mark.asyncio +async def test_add_bot_channel_success(): + + mock_extension.reset_mock() + + bot_id = "test_bot_id" + channel_content = JBChannelContent( + name = "telegram", + type = "telegram", + url = "test_url", + app_id = "12345678", + key = "test_key", + status = "active" + ) + + mock_add_channel_response = {"status": "success"} + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v2.bot import add_bot_channel + + with patch("app.routers.v2.bot.add_channel", return_value = mock_add_channel_response) as mock_add_channel: + + result = await add_bot_channel(bot_id, channel_content) + + assert len(result) == 1 + assert 'status' in result + assert result.get('status') == 'success' + + mock_add_channel.assert_awaited_once_with(bot_id, channel_content) + +@pytest.mark.asyncio +async def test_add_bot_channel_failure_when_channel_not_supported_by_this_manager(): + + mock_extension.reset_mock() + + bot_id = "test_bot_id" + channel_content = JBChannelContent( + name = "test_name", + type = "test_type", + url = "test_url", + app_id = "12345678", + key = "test_key", + status = "active" + ) + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v2.bot import add_bot_channel + + with pytest.raises(HTTPException) as exception_info: + + await add_bot_channel(bot_id, channel_content) + + assert exception_info.value.status_code == 400 + assert exception_info.value.detail == "Channel not supported by this manager" + +@pytest.mark.asyncio +async def test_add_bot_channel_failure_when_bot_not_found(): + + mock_extension.reset_mock() + + bot_id = None + + channel_content = JBChannelContent( + name = "telegram", + type = "telegram", + url = "test_url", + app_id = "12345678", + key = "test_key", + status = "active" + ) + + mock_add_channel_response = {"status": "error", "message": "Bot not found"} + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + + from app.routers.v2.bot import add_bot_channel + + with patch("app.routers.v2.bot.add_channel", return_value = mock_add_channel_response) as mock_add_channel: + with pytest.raises(HTTPException) as exception_info: + + await add_bot_channel(bot_id, channel_content) + + assert exception_info.value.status_code == 404 + assert exception_info.value.detail == mock_add_channel_response["message"] + + mock_add_channel.assert_awaited_once(bot_id, channel_content) \ No newline at end of file diff --git a/api/tests/test_routers_callback.py b/api/tests/test_routers_callback.py new file mode 100644 index 00000000..409ddf81 --- /dev/null +++ b/api/tests/test_routers_callback.py @@ -0,0 +1,111 @@ +from unittest.mock import AsyncMock, patch, MagicMock +from fastapi import HTTPException, Request +import pytest +from unittest import mock +from lib.data_models.channel import Channel, ChannelIntent, RestBotInput +from lib.channel_handler.telegram_handler import TelegramHandler + +mock_extension = MagicMock() + +@pytest.mark.asyncio +async def test_callback_success(): + + mock_extension.reset_mock() + + provider = "telegram" + bot_identifier = "test_bot_identifier" + + mock_handler = mock.Mock(spec= TelegramHandler) + mock_handler.get_channel_name.return_value = "telegram" + mock_handler.is_valid_data.return_value = True + + channel_map = {"telegram": mock_handler} + + request = { + "query_params": { + "search": "fastapi", + "page": "1" + }, + "headers": { + "authorization": "test_token_value", + "x-custom-header": "test_custom_value" + }, + "user": "test_bot_user", + "message_data":"test_data", + "credentials" : {"key" : "test_key"} + } + + mock_request = AsyncMock(Request) + mock_request.json.return_value = request + mock_request.query_params = request["query_params"] + mock_request.headers = request["headers"] + + mock_channel_input = Channel( + source="api", + turn_id="test_turn_id", + intent=ChannelIntent.CHANNEL_IN, + bot_input=RestBotInput( + channel_name="telegram", + headers=request.get("headers"), + data={"message_data":"test_data"}, + query_params=request.get("query_params"), + ), + ) + mock_handle_callback = AsyncMock() + mock_handle_callback.__aiter__.return_value = iter([(None, mock_channel_input)]) + + with patch.dict("sys.modules", {"app.extensions": mock_extension}),\ + patch.dict("app.routers.v2.callback.channel_map", channel_map): + + from app.routers.v2.callback import callback + + with patch("app.routers.v2.callback.handle_callback") as mock_handle_callback: + + result_status_code = await callback(provider, bot_identifier, mock_request) + + assert result_status_code == 200 + mock_handle_callback.assert_called_once_with(bot_identifier=bot_identifier, + callback_data=request, + headers=request.get("headers"), + query_params=request.get("query_params"), + chosen_channel=mock_handler) +@pytest.mark.asyncio +async def test_callback_failure_when_no_valid_channel(): + mock_extension.reset_mock() + + provider = "invalid_channel" + bot_identifier = "test_bot_identifier" + + mock_handler = mock.Mock(spec= TelegramHandler) + mock_handler.get_channel_name.return_value = "telegram" + mock_handler.is_valid_data.return_value = True + + channel_map = {"telegram": mock_handler} + + request = { + "query_params": { + "search": "fastapi", + "page": "1" + }, + "headers": { + "authorization": "test_token_value", + "x-custom-header": "test_custom_value" + }, + "user": "test_bot_user", + "message_data":"test_data", + "credentials" : {"key" : "test_key"} + } + + mock_request = AsyncMock(Request) + mock_request.json.return_value = request + mock_request.query_params = request["query_params"] + mock_request.headers = request["headers"] + + with patch.dict("sys.modules", {"app.extensions": mock_extension}),\ + patch.dict("app.routers.v2.callback.channel_map", channel_map): + + from app.routers.v2.callback import callback + + result_status_code = await callback(provider, bot_identifier, mock_request) + + assert result_status_code == 404 \ No newline at end of file diff --git a/api/tests/test_routers_channel.py b/api/tests/test_routers_channel.py new file mode 100644 index 00000000..799e5678 --- /dev/null +++ b/api/tests/test_routers_channel.py @@ -0,0 +1,192 @@ +from fastapi import HTTPException, Request +import pytest +from unittest.mock import AsyncMock, patch, MagicMock + +mock_extension = MagicMock() + +@pytest.mark.asyncio +async def test_get_all_channels(): + mock_extension.reset_mock() + + mock_channels_list = ["pinnacle_whatsapp","telegram","custom"] + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v2.channel import get_all_channels + + with patch("app.routers.v2.channel.list_available_channels", return_value = mock_channels_list) as mock_list_available_channels: + result = await get_all_channels() + + assert len(result) == len(mock_channels_list) + for i in range(len(result)): + assert result[i] == mock_channels_list[i] + + mock_list_available_channels.assert_awaited_once() + +@pytest.mark.asyncio +async def test_update_channel_success(): + mock_extension.reset_mock() + + channel_id = "test_channel_id" + + request_body = {"key":"test_key", "app_id":"12345678", "name":"telegram", "type":"telegram", "url":"test_url"} + mock_request = AsyncMock(Request) + mock_request.json.return_value = request_body + + updated_info = {"status": "success"} + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v2.channel import update_channel + + with patch("app.routers.v2.channel.update", return_value = updated_info) as mock_update: + result = await update_channel(channel_id, mock_request) + + assert len(result) == 1 + assert len(result) == len(updated_info) + assert 'status' in result + assert result.get('status') == 'success' + + mock_update.assert_awaited_once_with(channel_id, request_body) + +@pytest.mark.asyncio +async def test_update_channel_failure_when_channel_type_is_invalid(): + mock_extension.reset_mock() + + channel_id = "test_channel_id" + + request_body = {"key":"test_key", "app_id":"12345678", "name":"test_name", "type":"test_type", "url":"test_url"} + mock_request = AsyncMock(Request) + mock_request.json.return_value = request_body + + updated_info = {"status": "error", "message": "Channel not supported by this manager"} + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v2.channel import update_channel + + with patch("app.routers.v2.channel.update", return_value = updated_info) as mock_update: + with pytest.raises(HTTPException) as exception_info: + await update_channel(channel_id, mock_request) + + assert exception_info.value.status_code == 404 + assert exception_info.value.detail == updated_info["message"] + + mock_update.assert_awaited_once_with(channel_id, request_body) + +@pytest.mark.asyncio +async def test_activate_channel_success(): + mock_extension.reset_mock() + + channel_id = "test_channel_id" + + updated_info = {"status": "success"} + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v2.channel import activate_channel + + with patch("app.routers.v2.channel.activate", return_value = updated_info) as mock_activate: + result = await activate_channel(channel_id) + + assert len(result) == 1 + assert 'status' in result + assert result.get('status') == 'success' + + mock_activate.assert_awaited_once_with(channel_id) + +@pytest.mark.asyncio +async def test_activate_channel_failure_when_channel_not_found(): + mock_extension.reset_mock() + + channel_id = "test_channel_id" + + updated_info = {"status": "error", "message": "Channel not found"} + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v2.channel import activate_channel + + with patch("app.routers.v2.channel.activate", return_value = updated_info) as mock_activate: + with pytest.raises(HTTPException) as exception_info: + await activate_channel(channel_id) + + assert exception_info.value.status_code == 404 + assert exception_info.value.detail == updated_info["message"] + + mock_activate.assert_awaited_once_with(channel_id) + +@pytest.mark.asyncio +async def test_deactivate_channel_success(): + mock_extension.reset_mock() + + channel_id = "test_channel_id" + + updated_info = {"status": "success"} + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v2.channel import deactivate_channel + + with patch("app.routers.v2.channel.deactivate", return_value = updated_info) as mock_deactivate: + result = await deactivate_channel(channel_id) + + assert len(result) == 1 + assert 'status' in result + assert result.get('status') == 'success' + + mock_deactivate.assert_awaited_once_with(channel_id) + +@pytest.mark.asyncio +async def test_deactivate_channel_failure_when_channel_not_found(): + mock_extension.reset_mock() + + channel_id = "test_channel_id" + + updated_info = {"status": "error", "message": "Channel not found"} + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v2.channel import deactivate_channel + + with patch("app.routers.v2.channel.deactivate", return_value = updated_info) as mock_deactivate: + with pytest.raises(HTTPException) as exception_info: + await deactivate_channel(channel_id) + + assert exception_info.value.status_code == 404 + assert exception_info.value.detail == updated_info["message"] + + mock_deactivate.assert_awaited_once_with(channel_id) + +@pytest.mark.asyncio +async def test_add_channel_success(): + mock_extension.reset_mock() + + channel_id = "test_channel_id" + + delete_response = {"status": "success"} + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v2.channel import add_channel + + with patch("app.routers.v2.channel.delete", return_value = delete_response) as mock_delete: + result = await add_channel(channel_id) + + assert len(result) == 1 + assert 'status' in result + assert result.get('status') == 'success' + + mock_delete.assert_awaited_once_with(channel_id) + +@pytest.mark.asyncio +async def test_add_channel_failure_when_channel_not_found(): + mock_extension.reset_mock() + + channel_id = "test_channel_id" + + delete_response = {"status": "error", "message": "Channel not found"} + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v2.channel import add_channel + + with patch("app.routers.v2.channel.delete", return_value = delete_response) as mock_delete: + with pytest.raises(HTTPException) as exception_info: + await add_channel(channel_id) + + assert exception_info.value.status_code == 404 + assert exception_info.value.detail == delete_response["message"] + + mock_delete.assert_awaited_once_with(channel_id) From 52f41a1e071acb3ac0e605caaa2c529c505a531e Mon Sep 17 00:00:00 2001 From: Sunandhita B Date: Mon, 9 Dec 2024 12:52:38 +0530 Subject: [PATCH 04/11] Modified api tests --- api/tests/test_bot.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/api/tests/test_bot.py b/api/tests/test_bot.py index 8a663472..7e5f7e8b 100644 --- a/api/tests/test_bot.py +++ b/api/tests/test_bot.py @@ -143,7 +143,8 @@ async def test_add_channel_when_channel_already_in_use_by_bot(): type = "test_type", key = "test_key", app_id = "12345678", - url = "test_url" + url = "test_url", + bot = mock_bot ) with patch("app.handlers.v2.bot.get_bot_by_id", return_value = mock_bot) as mock_get_bot_by_id, \ @@ -155,6 +156,7 @@ async def test_add_channel_when_channel_already_in_use_by_bot(): assert 'status' in result assert result.get('status') == 'error' assert 'message' in result + assert result.get('message') == f"App ID {channel_content.app_id} already in use by bot {mock_existing_channel.bot.name}" mock_get_bot_by_id.assert_awaited_once_with(bot_id) mock_get_active_channel_by_identifier.assert_awaited_once_with(identifier = channel_content.app_id, From 8093a34f8029eccf2f5c50c7a8639c03a78de832 Mon Sep 17 00:00:00 2001 From: Sunandhita B Date: Mon, 9 Dec 2024 13:00:37 +0530 Subject: [PATCH 05/11] Modified api test file test_bot --- api/tests/test_bot.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/api/tests/test_bot.py b/api/tests/test_bot.py index 7e5f7e8b..1bca7e5d 100644 --- a/api/tests/test_bot.py +++ b/api/tests/test_bot.py @@ -64,7 +64,8 @@ async def test_add_credentials_success(): mock_bot = JBBot(id="test_bot_id", name="mock_bot", status="active") with patch("app.handlers.v2.bot.get_bot_by_id", return_value = mock_bot) as mock_get_bot_by_id, \ - patch("app.handlers.v2.bot.update_bot", return_value = bot_id) as mock_update_bot: + patch("app.handlers.v2.bot.update_bot", return_value = bot_id) as mock_update_bot, \ + patch("app.handlers.v2.bot.EncryptionHandler.encrypt_dict", return_value = "encrypted_test_key"): result = await add_credentials(bot_id,credentials) From 0723b8a1a5cd463e5e0f3dbcef6e7301d103e576 Mon Sep 17 00:00:00 2001 From: Sunandhita B Date: Mon, 9 Dec 2024 13:21:34 +0530 Subject: [PATCH 06/11] Modified api tests --- api/tests/test_bot.py | 9 ++++++--- api/tests/test_channel.py | 4 +++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/api/tests/test_bot.py b/api/tests/test_bot.py index 1bca7e5d..78cd45d2 100644 --- a/api/tests/test_bot.py +++ b/api/tests/test_bot.py @@ -65,7 +65,7 @@ async def test_add_credentials_success(): with patch("app.handlers.v2.bot.get_bot_by_id", return_value = mock_bot) as mock_get_bot_by_id, \ patch("app.handlers.v2.bot.update_bot", return_value = bot_id) as mock_update_bot, \ - patch("app.handlers.v2.bot.EncryptionHandler.encrypt_dict", return_value = "encrypted_test_key"): + patch("app.handlers.v2.bot.EncryptionHandler.encrypt_dict", return_value = "encrypted_test_key") as mock_encrypt_dict: result = await add_credentials(bot_id,credentials) @@ -75,6 +75,7 @@ async def test_add_credentials_success(): mock_get_bot_by_id.assert_awaited_once_with(bot_id) mock_update_bot.assert_awaited_once() + mock_encrypt_dict.assert_called_once_with(credentials) @pytest.mark.asyncio async def test_add_credentials_failure(): @@ -185,14 +186,15 @@ async def test_add_channel_when_channel_creation_is_success(): status = "active", name = "test_channel", type = "test_type", - key = "test_key", + key = "encrypted_test_key", app_id = "12345678", url = "test_url" ) with patch("app.handlers.v2.bot.get_bot_by_id", return_value = mock_bot) as mock_get_bot_by_id, \ patch("app.handlers.v2.bot.get_active_channel_by_identifier", return_value = None) as mock_get_active_channel_by_identifier, \ - patch("app.handlers.v2.bot.create_channel",return_value = mock_channel) as mock_create_channel: + patch("app.handlers.v2.bot.create_channel",return_value = mock_channel) as mock_create_channel, \ + patch("app.handlers.v2.bot.EncryptionHandler.encrypt_text", return_value = "encrypted_test_key") as mock_encrypt_text: result = await add_channel(bot_id,channel_content) @@ -204,6 +206,7 @@ async def test_add_channel_when_channel_creation_is_success(): mock_get_active_channel_by_identifier.assert_awaited_once_with(identifier = channel_content.app_id, channel_type = channel_content.type) mock_create_channel.assert_awaited_once_with(bot_id, channel_content.model_dump()) + mock_encrypt_text.assert_called_once() @pytest.mark.asyncio async def test_delete_success(): diff --git a/api/tests/test_channel.py b/api/tests/test_channel.py index c9e2eab1..3b5889fb 100644 --- a/api/tests/test_channel.py +++ b/api/tests/test_channel.py @@ -75,7 +75,8 @@ async def test_update_success(): ) with patch("app.handlers.v2.channel.get_channel_by_id", return_value = mock_channel_object) as mock_get_channel_by_id, \ - patch("app.handlers.v2.channel.update_channel", return_value = channel_id) as mock_update_channel: + patch("app.handlers.v2.channel.update_channel", return_value = channel_id) as mock_update_channel, \ + patch("app.handlers.v2.bot.EncryptionHandler.encrypt_text", return_value = "encrypted_test_key") as mock_encrypt_text: result = await update(channel_id, channel_data) @@ -85,6 +86,7 @@ async def test_update_success(): mock_get_channel_by_id.assert_awaited_once_with(channel_id) mock_update_channel.assert_awaited_once() + mock_encrypt_text.assert_called_once() @pytest.mark.asyncio async def test_activate_failure_when_channel_not_found(): From 60ee25cb5b790f32fb6a61d5e3cc7196abddaad3 Mon Sep 17 00:00:00 2001 From: Sunandhita B Date: Thu, 12 Dec 2024 13:11:56 +0530 Subject: [PATCH 07/11] Modified api tests to improve test coverage --- api/tests/test_bot.py | 45 ++ api/tests/test_bot_handlers.py | 209 ++++++++++ api/tests/test_channel.py | 12 - api/tests/test_crud.py | 57 ++- api/tests/test_extensions.py | 39 -- api/tests/test_routers_bot.py | 14 +- api/tests/test_routers_callback.py | 50 ++- api/tests/test_routers_v1_init.py | 631 +++++++++++++++++++++++++++++ 8 files changed, 996 insertions(+), 61 deletions(-) create mode 100644 api/tests/test_bot_handlers.py delete mode 100644 api/tests/test_extensions.py create mode 100644 api/tests/test_routers_v1_init.py diff --git a/api/tests/test_bot.py b/api/tests/test_bot.py index 78cd45d2..8b79413a 100644 --- a/api/tests/test_bot.py +++ b/api/tests/test_bot.py @@ -164,6 +164,51 @@ async def test_add_channel_when_channel_already_in_use_by_bot(): mock_get_active_channel_by_identifier.assert_awaited_once_with(identifier = channel_content.app_id, channel_type = channel_content.type) +@pytest.mark.asyncio +async def test_add_channel_when_bot_already_has_channel_of_given_type_for_given_app_id(): + + bot_id = "test_bot_id" + + channel_content = JBChannelContent( + name = "test_channel_content", + type = "test_type", + url = "test_url", + app_id = "12345678", + key = "test_key", + status = "active" + ) + + mock_bot = JBBot( + id="test_bot_id", + name="Bot1", + status="active", + channels = [JBChannel( + id = "test_channel_id", + bot_id = "test_bot_id", + status = "active", + name = "test_channel", + type = "test_type", + key = "test_key", + app_id = "12345678", + url = "test_url", + )] + ) + + with patch("app.handlers.v2.bot.get_bot_by_id", return_value = mock_bot) as mock_get_bot_by_id, \ + patch("app.handlers.v2.bot.get_active_channel_by_identifier", return_value = None) as mock_get_active_channel_by_identifier: + + result = await add_channel(bot_id,channel_content) + + assert len(result) == 2 + assert 'status' in result + assert result.get('status') == 'error' + assert 'message' in result + assert result.get('message') == f"Bot already has an channel of type {channel_content.type} for app ID {channel_content.app_id}" + + mock_get_bot_by_id.assert_awaited_once_with(bot_id) + mock_get_active_channel_by_identifier.assert_awaited_once_with(identifier = channel_content.app_id, + channel_type = channel_content.type) + @pytest.mark.asyncio async def test_add_channel_when_channel_creation_is_success(): diff --git a/api/tests/test_bot_handlers.py b/api/tests/test_bot_handlers.py new file mode 100644 index 00000000..f9393e1c --- /dev/null +++ b/api/tests/test_bot_handlers.py @@ -0,0 +1,209 @@ +from unittest.mock import MagicMock, patch, Mock +import pytest +from app.jb_schema import JBBotCode +from app.handlers.v1.bot_handlers import handle_deactivate_bot, handle_delete_bot, handle_install_bot, handle_update_channel +from lib.data_models.flow import Bot, BotConfig, BotIntent, Flow, FlowIntent +from lib.models import JBBot, JBChannel + +@pytest.mark.asyncio +async def test_handle_install_bot(): + + mock_jbbot_code = JBBotCode( + name = "Bot1", + status = "active", + dsl = "test_dsl", + code = "test_code", + requirements = "codaio", + index_urls = ["index_url_1","index_url_2"], + version = "1.0.0", + ) + + mock_bot_id = "test_bot_id" + + with patch("app.handlers.v1.bot_handlers.uuid.uuid4", return_value = mock_bot_id): + result = await handle_install_bot(mock_jbbot_code) + + assert isinstance(result, Flow) + assert result.source == "api" + assert result.intent == FlowIntent.BOT + assert isinstance(result.bot_config, BotConfig) + assert result.bot_config.bot_id == mock_bot_id + + assert result.bot_config.bot_id == mock_bot_id + assert result.bot_config.intent == BotIntent.INSTALL + assert isinstance(result.bot_config.bot, Bot) + assert result.bot_config.bot.name == mock_jbbot_code.name + assert result.bot_config.bot.fsm_code== mock_jbbot_code.code + assert result.bot_config.bot.requirements_txt == mock_jbbot_code.requirements + assert result.bot_config.bot.index_urls == mock_jbbot_code.index_urls + assert result.bot_config.bot.required_credentials == mock_jbbot_code.required_credentials + assert result.bot_config.bot.version == mock_jbbot_code.version + +@pytest.mark.asyncio +async def test_handle_update_channel_success(): + + channel_id = "test_channel_id" + channel_data = {"key":"test_key", "app_id":"12345678", "name":"telegram", "type":"telegram", "url":"test_url"} + + mock_channel_object = JBChannel( + id = "test_channel_id", + bot_id = "test_bot_id", + status = "active", + name = "telegram", + type = "telegram", + key = "test_key", + app_id = "12345678", + url = "test_url", + ) + + with patch("app.handlers.v1.bot_handlers.EncryptionHandler.encrypt_text", return_value = "encrypted_test_key") as mock_encrypt_text, \ + patch("app.handlers.v1.bot_handlers.get_channel_by_id", return_value = mock_channel_object) as mock_get_channel_by_id, \ + patch("app.handlers.v1.bot_handlers.update_channel", return_value = channel_id) as mock_update_channel: + + result = await handle_update_channel(channel_id, channel_data) + + assert len(result) == 3 + assert 'status' in result + assert result.get('status') == 'success' + assert 'message' in result + assert result.get('message') == 'Channel updated' + assert 'channel' in result + assert isinstance(result.get('channel'),JBChannel) + assert result.get('channel') == mock_channel_object + + mock_get_channel_by_id.assert_awaited_once_with(channel_id) + mock_update_channel.assert_awaited_once() + +@pytest.mark.asyncio +async def test_handle_update_channel_when_channel_not_found(): + + channel_id = "test_channel_id" + channel_data = {"key":"test_key", "app_id":"12345678", "name":"telegram", "type":"telegram", "url":"test_url"} + + with patch("app.handlers.v1.bot_handlers.EncryptionHandler.encrypt_text", return_value = "encrypted_test_key") as mock_encrypt_text, \ + patch("app.handlers.v1.bot_handlers.get_channel_by_id", return_value = None) as mock_get_channel_by_id: + + result = await handle_update_channel(channel_id, channel_data) + + assert len(result) == 2 + assert 'status' in result + assert result.get('status') == 'error' + assert 'message' in result + assert result.get('message') == 'Channel not found' + + mock_get_channel_by_id.assert_awaited_once_with(channel_id) + +@pytest.mark.asyncio +async def test_handle_delete_bot_success(): + + channel_id = "test_channel_id" + bot_id = "test_bot_id" + bot_data = {"status": "deleted"} + + mock_channel = JBChannel( + id = "test_channel_id", + bot_id = "test_bot_id", + status = "active", + name = "test_channel", + type = "test_type", + key = "test_key", + app_id = "12345678", + url = "test_url", + ) + bot = JBBot(id="test_bot_id", name="Bot1", status="active", channels = [mock_channel]) + + updated_info = {"status": "success", "message": "Bot updated", "bot": bot} + + with patch("app.handlers.v1.bot_handlers.handle_update_bot", return_value = updated_info) as mock_handle_update_bot, \ + patch("app.handlers.v1.bot_handlers.update_channel", return_value = channel_id) as mock_update_channel: + + result = await handle_delete_bot(bot_id) + + assert len(result) == 2 + assert 'status' in result + assert result.get('status') == 'success' + assert 'message' in result + assert result.get('message') == 'Bot deleted' + + mock_handle_update_bot.assert_awaited_once_with(bot_id, bot_data) + mock_update_channel.assert_awaited_once() + +@pytest.mark.asyncio +async def test_handle_delete_bot_when_bot_not_found(): + + bot_id = "test_bot_id" + bot_data = {"status": "deleted"} + + mock_channel = JBChannel( + id = "test_channel_id", + bot_id = "test_bot_id", + status = "active", + name = "test_channel", + type = "test_type", + key = "test_key", + app_id = "12345678", + url = "test_url", + ) + + updated_info = {"status": "error", "message": "Bot not found"} + + with patch("app.handlers.v1.bot_handlers.handle_update_bot", return_value = updated_info) as mock_handle_update_bot: + + result = await handle_delete_bot(bot_id) + + assert len(result) == 2 + assert 'status' in result + assert result.get('status') == 'error' + assert 'message' in result + assert result.get('message') == 'Bot not found' + + mock_handle_update_bot.assert_awaited_once_with(bot_id, bot_data) + +@pytest.mark.asyncio +async def test_handle_deactivate_bot_success(): + + channel_id = "test_channel_id" + bot_id = "test_bot_id" + + mock_channel = JBChannel( + id = "test_channel_id", + bot_id = "test_bot_id", + status = "active", + name = "test_channel", + type = "test_type", + key = "test_key", + app_id = "12345678", + url = "test_url", + ) + bot = JBBot(id="test_bot_id", name="Bot1", status="active", channels = [mock_channel]) + + with patch("app.handlers.v1.bot_handlers.get_bot_by_id", return_value = bot) as mock_get_bot_by_id, \ + patch("app.handlers.v1.bot_handlers.update_channel", return_value = channel_id) as mock_update_channel: + + result = await handle_deactivate_bot(bot_id) + + assert len(result) == 2 + assert 'status' in result + assert result.get('status') == 'success' + assert 'message' in result + assert result.get('message') == 'Bot deactivated' + + mock_get_bot_by_id.assert_awaited_once_with(bot_id) + mock_update_channel.assert_awaited_once() + +@pytest.mark.asyncio +async def test_handle_deactivate_bot_when_bot_not_found(): + + bot_id = "test_bot_id" + + with patch("app.handlers.v1.bot_handlers.get_bot_by_id", return_value = None) as mock_get_bot_by_id: + + result = await handle_deactivate_bot(bot_id) + + assert len(result) == 2 + assert 'status' in result + assert result.get('status') == 'error' + assert 'message' in result + assert result.get('message') == 'Bot not found' + + mock_get_bot_by_id.assert_awaited_once_with(bot_id) \ No newline at end of file diff --git a/api/tests/test_channel.py b/api/tests/test_channel.py index 3b5889fb..4431944d 100644 --- a/api/tests/test_channel.py +++ b/api/tests/test_channel.py @@ -194,18 +194,6 @@ async def test_delete_failure_when_channel_not_found(): mock_get_channel_by_id.assert_awaited_once_with(channel_id) - - - - - - - - - - - - @pytest.mark.asyncio async def test_delete_success(): diff --git a/api/tests/test_crud.py b/api/tests/test_crud.py index f899bde9..7f76647e 100644 --- a/api/tests/test_crud.py +++ b/api/tests/test_crud.py @@ -18,7 +18,8 @@ get_bot_chat_sessions, create_bot, create_channel, - get_active_channel_by_identifier + get_active_channel_by_identifier, + get_chat_history ) class AsyncContextManagerMock: @@ -708,4 +709,56 @@ async def test_get_active_channel_by_identifier_failure(): with mock.patch.object(DBSessionHandler, 'get_async_session', side_effect=Exception("Database error")): with pytest.raises(Exception): - await get_active_channel_by_identifier(identifier, channel_type) \ No newline at end of file + await get_active_channel_by_identifier(identifier, channel_type) + +@pytest.mark.asyncio +async def test_get_chat_history_success(): + bot_id = "test_bot_id" + skip = 0 + limit = 1000 + + mock_session_object = JBSession( + id = "test_session_id", + user_id = "test_user_id", + channel_id = "test_channel_id" + ) + + mock_user_object = JBUser( + id = "test_user_id", + channel_id = "test_channel_id", + first_name = "test_first_name", + last_name = "test_last_name", + identifier = "test_identifier" + ) + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + + mock_execute_result = mock.MagicMock() + mock_execute_result.__iter__.return_value = iter([iter([mock_session_object]), iter([mock_user_object])]) + mock_session.execute = mock.AsyncMock(return_value=mock_execute_result) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + + result = await get_chat_history(bot_id, skip, limit) + + assert result is not None + assert isinstance(result, list) + assert len(result) == 2 + + assert result[0] == [mock_session_object] + assert result[1] == [mock_user_object] + + mock_session.execute.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_chat_history_failure(): + + bot_id = "test_bot_id" + skip = 0 + limit = 1000 + + with mock.patch.object(DBSessionHandler, 'get_async_session', side_effect=Exception("Database error")): + with pytest.raises(Exception): + await get_chat_history(bot_id, skip, limit) \ No newline at end of file diff --git a/api/tests/test_extensions.py b/api/tests/test_extensions.py deleted file mode 100644 index 8181a158..00000000 --- a/api/tests/test_extensions.py +++ /dev/null @@ -1,39 +0,0 @@ -# import pytest -# from unittest.mock import patch, MagicMock -# from lib.db_session_handler import DBSessionHandler -# from lib.data_models import Flow -# from lib.kafka_utils import KafkaProducer -# from app.extensions import produce_message -# import os - -# # channel_input = Channel( -# # source="api", -# # turn_id=turn_id, -# # intent=ChannelIntent.CHANNEL_IN, -# # bot_input=RestBotInput( -# # channel_name=chosen_channel.get_channel_name(), -# # headers=headers, -# # data=message_data, -# # query_params=query_params, -# # ), -# # ) -# @pytest.mark.asyncio -# @patch('api.app.extensions.KafkaProducer.from_env_vars') -# @patch.dict(os.environ, {'KAFKA_FLOW_TOPIC': 'flow', 'KAFKA_CHANNEL_TOPIC': 'channel', 'KAFKA_INDEXER_TOPIC': 'indexer'}) -# async def test_produce_message_for_flow_instance(mock_kafka_producer): - -# flow_input = Flow( -# source = "Api", -# intent = "User input" -# ) -# mock_producer = MagicMock() -# mock_kafka_producer.return_value = mock_producer - -# flow_topic = "flow_topic" - -# produce_message(flow_input) - -# mock_producer.send_message.assert_called_once_with( -# topic=flow_topic, -# value=flow_input.model_dump_json(exclude_none=True) -# ) diff --git a/api/tests/test_routers_bot.py b/api/tests/test_routers_bot.py index 62084352..3ed086da 100644 --- a/api/tests/test_routers_bot.py +++ b/api/tests/test_routers_bot.py @@ -165,8 +165,8 @@ async def test_add_bot_credentials_failure_when_no_credentials_have_been_provide await add_bot_credentials(bot_id, mock_request) - assert exception_info.value.status_code == 400 - assert exception_info.value.detail == "No credentials provided" + assert exception_info.value.status_code == 400 + assert exception_info.value.detail == "No credentials provided" @pytest.mark.asyncio async def test_add_bot_credentials_failure_when_no_bot_found_with_given_bot_id(): @@ -191,10 +191,10 @@ async def test_add_bot_credentials_failure_when_no_bot_found_with_given_bot_id() await add_bot_credentials(bot_id, mock_request) - assert exception_info.value.status_code == 404 - assert exception_info.value.detail == mock_add_credentials_response["message"] + assert exception_info.value.status_code == 404 + assert exception_info.value.detail == mock_add_credentials_response["message"] - mock_add_credentials.assert_awaited_once() + mock_add_credentials.assert_awaited_once() @pytest.mark.asyncio async def test_add_bot_channel_success(): @@ -248,8 +248,8 @@ async def test_add_bot_channel_failure_when_channel_not_supported_by_this_manage await add_bot_channel(bot_id, channel_content) - assert exception_info.value.status_code == 400 - assert exception_info.value.detail == "Channel not supported by this manager" + assert exception_info.value.status_code == 400 + assert exception_info.value.detail == "Channel not supported by this manager" @pytest.mark.asyncio async def test_add_bot_channel_failure_when_bot_not_found(): diff --git a/api/tests/test_routers_callback.py b/api/tests/test_routers_callback.py index 409ddf81..800ef2fc 100644 --- a/api/tests/test_routers_callback.py +++ b/api/tests/test_routers_callback.py @@ -108,4 +108,52 @@ async def test_callback_failure_when_no_valid_channel(): result_status_code = await callback(provider, bot_identifier, mock_request) - assert result_status_code == 404 \ No newline at end of file + assert result_status_code == 404 + +@pytest.mark.asyncio +async def test_callback_failure_when_active_channel_not_found(): + mock_extension.reset_mock() + + provider = "telegram" + bot_identifier = "test_bot_identifier" + + request = { + "query_params": {}, + "headers": {}, + "user": "test_bot_user", + "message_data":"test_data", + "credentials" : {"key" : "test_key"} + } + + mock_request = AsyncMock(Request) + mock_request.json.return_value = request + mock_request.query_params = request["query_params"] + mock_request.headers = request["headers"] + + mock_handler = mock.Mock(spec= TelegramHandler) + mock_handler.get_channel_name.return_value = "telegram" + mock_handler.is_valid_data.return_value = True + + channel_map = {"telegram": mock_handler} + + async def mock_handle_callback(bot_identifier = bot_identifier, + callback_data = request, + headers= request["headers"], + query_params= request["query_params"], + chosen_channel= mock_handler, + ): + yield ValueError("Active channel not found"), None + + with patch.dict("sys.modules", {"app.extensions": mock_extension}), \ + patch.dict("app.routers.v2.callback.channel_map", channel_map): + + with patch("app.routers.v2.callback.handle_callback", mock_handle_callback): + + from app.routers.v2.callback import callback + + with pytest.raises(HTTPException) as exception_info: + + await callback(provider, bot_identifier, mock_request) + + assert exception_info.value.status_code == 400 + assert exception_info.value.detail == str(ValueError("Active channel not found")) \ No newline at end of file diff --git a/api/tests/test_routers_v1_init.py b/api/tests/test_routers_v1_init.py new file mode 100644 index 00000000..5e4b2437 --- /dev/null +++ b/api/tests/test_routers_v1_init.py @@ -0,0 +1,631 @@ +from io import BytesIO +from unittest import mock +from unittest.mock import patch, MagicMock, AsyncMock +from fastapi import HTTPException, Request, UploadFile +import pytest +from app.jb_schema import JBBotActivate, JBBotChannels, JBBotCode +from lib.data_models.flow import Bot, BotConfig, BotIntent, Callback, CallbackType, Flow, FlowIntent +from lib.data_models.indexer import IndexType +from lib.models import JBBot, JBChannel, JBSession, JBUser + +mock_extension = MagicMock() + +@pytest.mark.asyncio +async def test_get_bots(): + mock_extension.reset_mock() + + mock_channel1 = JBChannel( + id = "test_channel_id", + bot_id = "test_bot_id", + status = "active", + name = "test_channel", + type = "test_type", + key = "test_key", + app_id = "12345678", + url = "test_url", + ) + + mock_channel2 = JBChannel( + id = "test_channel_id", + bot_id = "test_bot_id", + status = "inactive", + name = "test_channel", + type = "test_type", + key = "test_key", + app_id = "12345678", + url = "test_url", + ) + + mock_bot1 = JBBot(id="test_bot_1", name="Bot1", status="active", channels = [mock_channel1]) + mock_bot2 = JBBot(id="test_bot_2", name="Bot2", status="inactive", channels = [mock_channel2]) + + bots = [mock_bot1,mock_bot2] + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v1 import get_bots + + with patch("app.routers.v1.get_bot_list", return_value = bots) as mock_get_bot_list: + result = await get_bots() + + assert isinstance(result,list) + assert len(result) == len(bots) + for item in result: + assert isinstance(item, JBBot) + assert item.status == item.channels[0].status + assert result[0] == bots[0] + assert result[1] == bots[1] + + mock_get_bot_list.assert_awaited_once() + +@pytest.mark.asyncio +async def test_get_secret_key(): + + mock_extension.reset_mock() + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v1 import get_secret_key + + result = await get_secret_key() + + assert len(result) == 1 + assert 'secret' in result + assert result.get('secret') is not None + +@pytest.mark.asyncio +async def test_refresh_secret_key(): + mock_extension.reset_mock() + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v1 import refresh_secret_key + + result = await refresh_secret_key() + + assert len(result) == 1 + assert 'status' in result + assert result.get('status') == 'success' + +@pytest.mark.asyncio +async def test_install_bot_success(): + + mock_extension.reset_mock() + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v1 import install_bot + from app.routers.v1 import KEYS + mock_jbbot_code = JBBotCode( + name = "Bot1", + status = "active", + dsl = "test_dsl", + code = "test_code", + requirements = "codaio", + index_urls = ["index_url_1","index_url_2"], + version = "1.0.0", + ) + + request = { + "query_params": { + "search": "fastapi", + "page": "1" + }, + "headers": { + "authorization": f"Bearer {KEYS['JBMANAGER_KEY']}", + "x-custom-header": "test_custom_value" + }, + "user": "test_bot_user", + "message_data":"test_data", + "credentials" : {"key" : "test_key"} + } + + mock_request = AsyncMock(Request) + mock_request.json.return_value = request + mock_request.headers = request["headers"] + + mock_flow_input = Flow( + source="api", + intent=FlowIntent.BOT, + bot_config=BotConfig( + bot_id="test_bot_id", + intent=BotIntent.INSTALL, + bot=Bot( + name=mock_jbbot_code.name, + fsm_code=mock_jbbot_code.code, + requirements_txt=mock_jbbot_code.requirements, + index_urls=mock_jbbot_code.index_urls, + required_credentials=mock_jbbot_code.required_credentials, + version=mock_jbbot_code.version, + ), + ), + ) + with patch("app.routers.v1.handle_install_bot", return_value = mock_flow_input) as mock_handle_install_bot: + result = await install_bot(mock_request, mock_jbbot_code) + + assert len(result) == 1 + assert 'status' in result + assert result.get('status') == 'success' + mock_handle_install_bot.assert_awaited_once_with(mock_jbbot_code) + +@pytest.mark.asyncio +async def test_install_bot_failure_when_authorization_header_not_provided(): + + mock_extension.reset_mock() + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v1 import install_bot + + mock_jbbot_code = JBBotCode( + name = "Bot1", + status = "active", + dsl = "test_dsl", + code = "test_code", + requirements = "codaio", + index_urls = ["index_url_1","index_url_2"], + version = "1.0.0", + ) + + request = { + "query_params": { + "search": "fastapi", + "page": "1" + }, + "headers": { + "x-custom-header": "test_custom_value" + }, + "user": "test_bot_user", + "message_data":"test_data", + "credentials" : {"key" : "test_key"} + } + + mock_request = AsyncMock(Request) + mock_request.json.return_value = request + mock_request.headers = request["headers"] + + with pytest.raises(HTTPException) as exception_info: + await install_bot(mock_request, mock_jbbot_code) + + assert exception_info.value.status_code == 401 + assert exception_info.value.detail == "Authorization header not provided" + +@pytest.mark.asyncio +async def test_install_bot_failure_when_unauthorized(): + + mock_extension.reset_mock() + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v1 import install_bot + + mock_jbbot_code = JBBotCode( + name = "Bot1", + status = "active", + dsl = "test_dsl", + code = "test_code", + requirements = "codaio", + index_urls = ["index_url_1","index_url_2"], + version = "1.0.0", + ) + + request = { + "query_params": { + "search": "fastapi", + "page": "1" + }, + "headers": { + "authorization": "invalid_authorization", + "x-custom-header": "test_custom_value" + }, + "user": "test_bot_user", + "message_data":"test_data", + "credentials" : {"key" : "test_key"} + } + + mock_request = AsyncMock(Request) + mock_request.json.return_value = request + mock_request.headers = request["headers"] + + with pytest.raises(HTTPException) as exception_info: + await install_bot(mock_request, mock_jbbot_code) + + assert exception_info.value.status_code == 401 + assert exception_info.value.detail == "Unauthorized" + +@pytest.mark.asyncio +async def test_activate_bot_success(): + + mock_extension.reset_mock() + + bot_id = "test_bot_id" + request_body = JBBotActivate( + phone_number = "919876543210", + channels = JBBotChannels( + whatsapp = "whatsapp" + ) + ) + + activate_bot_response = {"status": "success"} + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v1 import activate_bot + + with patch("app.routers.v1.handle_activate_bot", return_value = activate_bot_response) as mock_handle_activate_bot: + result = await activate_bot(bot_id, request_body) + + assert len(result) == 1 + assert 'status' in result + assert result.get('status') == 'success' + mock_handle_activate_bot.assert_awaited_once_with(bot_id=bot_id, request_body=request_body) + +@pytest.mark.asyncio +async def test_activate_bot_failure_when_no_phone_number_provided(): + + mock_extension.reset_mock() + + bot_id = "test_bot_id" + request_body = JBBotActivate( + phone_number = "", + channels = JBBotChannels( + whatsapp = "whatsapp" + ) + ) + + activate_bot_response = {"status": "error", "message": "No phone number provided"} + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v1 import activate_bot + + with patch("app.routers.v1.handle_activate_bot", return_value = activate_bot_response) as mock_handle_activate_bot: + with pytest.raises(HTTPException) as exception_info: + await activate_bot(bot_id, request_body) + + assert exception_info.value.status_code == 400 + assert exception_info.value.detail == activate_bot_response["message"] + + mock_handle_activate_bot.assert_awaited_once_with(bot_id=bot_id, request_body=request_body) + +@pytest.mark.asyncio +async def test_get_bot_success(): + + mock_extension.reset_mock() + + bot_id = "test_bot_id" + + updated_info = {"status": "success", "message": "Bot deactivated"} + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v1 import get_bot + + with patch("app.routers.v1.handle_deactivate_bot", return_value = updated_info) as mock_handle_deactivate_bot: + result = await get_bot(bot_id) + + assert len(result) == 1 + assert 'status' in result + assert result.get('status') == 'success' + mock_handle_deactivate_bot.assert_awaited_once_with(bot_id) + +@pytest.mark.asyncio +async def test_get_bot_failure(): + + mock_extension.reset_mock() + + bot_id = "test_bot_id" + + updated_info = {"status": "error", "message": "Bot not found"} + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v1 import get_bot + + with patch("app.routers.v1.handle_deactivate_bot", return_value = updated_info) as mock_handle_deactivate_bot: + with pytest.raises(HTTPException) as exception_info: + await get_bot(bot_id) + + assert exception_info.value.status_code == 404 + assert exception_info.value.detail == updated_info["message"] + + mock_handle_deactivate_bot.assert_awaited_once_with(bot_id) + +@pytest.mark.asyncio +async def test_delete_bot_success(): + + mock_extension.reset_mock() + + bot_id = "test_bot_id" + + updated_info = {"status": "success", "message": "Bot deleted"} + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v1 import delete_bot + + with patch("app.routers.v1.handle_delete_bot", return_value = updated_info) as mock_handle_delete_bot: + result = await delete_bot(bot_id) + + assert len(result) == 1 + assert 'status' in result + assert result.get('status') == 'success' + mock_handle_delete_bot.assert_awaited_once_with(bot_id) + +@pytest.mark.asyncio +async def test_delete_bot_failure(): + + mock_extension.reset_mock() + + bot_id = "test_bot_id" + + updated_info = {"status": "error", "message": "Bot not found"} + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v1 import delete_bot + + with patch("app.routers.v1.handle_delete_bot", return_value = updated_info) as mock_handle_delete_bot: + with pytest.raises(HTTPException) as exception_info: + await delete_bot(bot_id) + + assert exception_info.value.status_code == 404 + assert exception_info.value.detail == updated_info["message"] + + mock_handle_delete_bot.assert_awaited_once_with(bot_id) + +@pytest.mark.asyncio +async def test_add_bot_configuraton_success(): + + mock_extension.reset_mock() + + bot_id = "test_bot_id" + request = { + "query_params": { + "search": "fastapi", + "page": "1" + }, + "headers": { + "authorization": "test_authorization", + "x-custom-header": "test_custom_value" + }, + "user": "test_bot_user", + "message_data":"test_data", + "credentials" : {"key" : "test_key"}, + "config_env" : {} + } + + mock_request = AsyncMock(Request) + mock_request.json.return_value = request + mock_request.credentials = request["credentials"] + mock_request.config_env = request["config_env"] + + updated_info ={"status": "success", "message": "Bot updated", "bot":"sample_bot"} + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v1 import add_bot_configuraton + + with patch("app.routers.v1.handle_update_bot", return_value = updated_info) as mock_handle_update_bot: + result = await add_bot_configuraton(bot_id, mock_request) + + assert len(result) == 1 + assert 'status' in result + assert result.get('status') == 'success' + + mock_handle_update_bot.assert_awaited_once() + +@pytest.mark.asyncio +async def test_add_bot_configuraton_failure_when_no_credentials_and_no_config_env_provided(): + + mock_extension.reset_mock() + + bot_id = "test_bot_id" + request = { + "query_params": { + "search": "fastapi", + "page": "1" + }, + "headers": { + "authorization": "test_authorization", + "x-custom-header": "test_custom_value" + }, + "user": "test_bot_user", + "message_data":"test_data", + } + + mock_request = AsyncMock(Request) + mock_request.json.return_value = request + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v1 import add_bot_configuraton + + with pytest.raises(HTTPException) as exception_info: + await add_bot_configuraton(bot_id, mock_request) + + assert exception_info.value.status_code == 400 + assert exception_info.value.detail == "No credentials or config_env provided" + +@pytest.mark.asyncio +async def test_add_bot_configuraton_failure_when_bot_not_found(): + + mock_extension.reset_mock() + + bot_id = "invalid_bot_id" + request = { + "query_params": { + "search": "fastapi", + "page": "1" + }, + "headers": { + "authorization": "test_authorization", + "x-custom-header": "test_custom_value" + }, + "user": "test_bot_user", + "message_data":"test_data", + "credentials" : {"key" : "test_key"}, + "config_env" : {} + } + + mock_request = AsyncMock(Request) + mock_request.json.return_value = request + mock_request.credentials = request["credentials"] + mock_request.config_env = request["config_env"] + + updated_info = {"status": "error", "message": "Bot not found"} + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v1 import add_bot_configuraton + + with patch("app.routers.v1.handle_update_bot", return_value = updated_info) as mock_handle_update_bot: + with pytest.raises(HTTPException) as exception_info: + await add_bot_configuraton(bot_id, mock_request) + + assert exception_info.value.status_code == 404 + assert exception_info.value.detail == updated_info["message"] + + mock_handle_update_bot.assert_awaited_once() + +@pytest.mark.asyncio +async def test_get_session(): + + mock_extension.reset_mock() + + bot_id = "test_bot_id" + session_id = "test_session_id" + + mock_chat_session1 = JBSession(id="test_session_id1", user_id="test_user_id1", channel_id="test_channel_id1") + mock_chat_session2 = JBSession(id="test_session_id2", user_id="test_user_id2", channel_id="test_channel_id2") + + sessions = [mock_chat_session1,mock_chat_session2] + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v1 import get_session + + with patch("app.routers.v1.get_bot_chat_sessions", return_value = sessions) as mock_get_bot_chat_sessions: + result = await get_session(bot_id, session_id) + + assert isinstance(result,list) + assert len(result) == len(sessions) + for item in result: + assert isinstance(item,JBSession) + assert item in sessions + + mock_get_bot_chat_sessions.assert_awaited_once_with(bot_id, session_id) + +@pytest.mark.asyncio +async def test_get_chats(): + + mock_extension.reset_mock() + + bot_id = "test_bot_id" + + mock_session_object = JBSession( + id = "test_session_id", + user_id = "test_user_id", + channel_id = "test_channel_id" + ) + + mock_user_object = JBUser( + id = "test_user_id", + channel_id = "test_channel_id", + first_name = "test_first_name", + last_name = "test_last_name", + identifier = "test_identifier" + ) + + chats = [mock_session_object, mock_user_object] + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v1 import get_chats + + with patch("app.routers.v1.get_chat_history", return_value = chats) as mock_get_chat_history: + result = await get_chats(bot_id) + + assert isinstance(result,list) + assert len(result) == len(chats) + for item in result: + assert isinstance(item,JBSession) or isinstance(item,JBUser) + assert item in chats + + mock_get_chat_history.assert_awaited_once_with(bot_id) + +@pytest.mark.asyncio +async def test_index_data(): + + test_file_1 = UploadFile(filename="test_file_1.txt", file=BytesIO(b"test content for 1st file")) + test_file_2 = UploadFile(filename="test_file_2.txt", file=BytesIO(b"test content for 2nd file")) + + indexer_type = IndexType.default + collection_name = "test_collection" + files = [test_file_1, test_file_2] + indexing_chunk_size = 4000 + indexing_chunk_overlap_size = 200 + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v1 import index_data + from lib.file_storage.handler import StorageHandler + + mock_storage_instance = MagicMock() + + with patch("app.routers.v1.StorageHandler.get_async_instance", return_value = mock_storage_instance): + + mock_storage_instance.write_file = AsyncMock() + + result = await index_data(indexer_type,collection_name, files, indexing_chunk_size, indexing_chunk_overlap_size) + + assert len(result) == 1 + assert 'message' in result + assert result.get('message') == f"Indexing started for the files in {collection_name}" + mock_storage_instance.write_file.assert_called() + +@pytest.mark.asyncio +async def test_plugin_webhook_success(): + + mock_extension.reset_mock() + + mock_webhook_data = b'{"data": "sample_body_data", "additional_field": "extra_value"}' + mock_webhook_data_decoded = '{"data": "sample_body_data", "additional_field": "extra_value"}' + + mock_request = AsyncMock(Request) + mock_request.body.return_value = mock_webhook_data + + flow_input = Flow( + source="api", + intent=FlowIntent.CALLBACK, + callback=Callback( + turn_id="test_turn_id", + callback_type=CallbackType.EXTERNAL, + external=mock_webhook_data, + ), + ) + + async def mock_handle_webhook(mock_webhook_data_decoded): + yield flow_input + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v1 import plugin_webhook + + with patch("app.routers.v1.handle_webhook", mock_handle_webhook): + + result_status_code = await plugin_webhook(mock_request) + + assert result_status_code == 200 + +@pytest.mark.asyncio +async def test_plugin_webhook_failure(): + + mock_extension.reset_mock() + + mock_webhook_data = b'{"data": "sample_body_data", "additional_field": "extra_value"}' + mock_webhook_data_decoded = '{"data": "sample_body_data", "additional_field": "extra_value"}' + + mock_request = AsyncMock(Request) + mock_request.body.return_value = mock_webhook_data + + flow_input = Flow( + source="api", + intent=FlowIntent.CALLBACK, + callback=Callback( + turn_id="test_turn_id", + callback_type=CallbackType.EXTERNAL, + external=mock_webhook_data, + ), + ) + + async def mock_handle_webhook(mock_webhook_data_decoded): + yield flow_input + + with patch.dict("sys.modules", {"app.extensions": mock_extension}): + from app.routers.v1 import plugin_webhook + + with patch("app.routers.v1.handle_webhook", mock_handle_webhook): + + result_status_code = await plugin_webhook(mock_request) + + assert result_status_code == 200 \ No newline at end of file From f2251a59e334bbc2dfd3327eb1fbd4ad661ec5b1 Mon Sep 17 00:00:00 2001 From: Sunandhita B Date: Fri, 13 Dec 2024 16:32:57 +0530 Subject: [PATCH 08/11] Added language tests to improve test coverage --- language/tests/test_audio_converter.py | 207 +++++++++++++++++++++++++ language/tests/test_crud.py | 89 +++++++++++ language/tests/test_incoming.py | 32 ++++ language/tests/test_outgoing.py | 29 ++++ 4 files changed, 357 insertions(+) create mode 100644 language/tests/test_audio_converter.py create mode 100644 language/tests/test_crud.py diff --git a/language/tests/test_audio_converter.py b/language/tests/test_audio_converter.py new file mode 100644 index 00000000..4078da36 --- /dev/null +++ b/language/tests/test_audio_converter.py @@ -0,0 +1,207 @@ +from io import BytesIO +from unittest.mock import AsyncMock, patch, MagicMock +from urllib.parse import ParseResult +import pytest +from src.audio_converter import _get_file_extension, _is_url, convert_to_wav, convert_to_wav_with_ffmpeg, convert_wav_bytes_to_mp3_bytes, get_filename_from_url + +@pytest.mark.parametrize( + "mock_string, parsed_url_string", + [("http://test.com/media/test_audio_file.mp3", + ParseResult(scheme='http', netloc='test.com', path='/media/test_audio_file.mp3', params='', query='', fragment='')), + ] +) +def test_is_url_with_valid_url(mock_string, parsed_url_string): + with patch("src.audio_converter.urlparse", return_value=parsed_url_string) as mock_url_parse: + result = _is_url(mock_string) + assert result == True + + mock_url_parse.assert_called_once_with(mock_string) + +@pytest.mark.parametrize( + "mock_string, parsed_url_string", + [("/media/test_audio_file.mp3", + ParseResult(scheme='', netloc='', path='/media/test_audio_file.mp3', params='', query='', fragment='')), + ] +) +def test_is_url_with_invalid_url(mock_string, parsed_url_string): + with patch("src.audio_converter.urlparse", return_value=parsed_url_string) as mock_url_parse: + + result = _is_url(mock_string) + + assert result == False + mock_url_parse.assert_called_once_with(mock_string) + +@pytest.mark.parametrize("mock_string", [("invalid string")]) +def test_is_url_failure(mock_string): + with patch("src.audio_converter.urlparse", side_effect = ValueError) as mock_url_parse: + + result = _is_url(mock_string) + + assert result == False + mock_url_parse.assert_called_once_with(mock_string) + +@pytest.mark.parametrize( + "mock_url, parsed_url", + [("http://test.com/media/test_audio_file.mp3", + ParseResult(scheme='http', netloc='test.com', path='/media/test_audio_file.mp3', params='', query='', fragment='')), + ] +) +def test_get_filename_from_url(mock_url, parsed_url): + with patch("src.audio_converter.urlparse", return_value=parsed_url) as mock_url_parse: + + result = get_filename_from_url(mock_url) + + assert result == "test_audio_file.mp3" + mock_url_parse.assert_called_once_with(mock_url) + +@pytest.mark.parametrize("mock_file_name_or_url", [("http://test.com/media/test_audio_file.mp3")]) +def test_get_file_extension_when_file_name_or_url_starts_with_http_or_https(mock_file_name_or_url): + + file_name = "test_audio_file.mp3" + + with patch("src.audio_converter.get_filename_from_url", return_value=file_name) as mock_get_filename_from_url: + + result = _get_file_extension(mock_file_name_or_url) + + assert result == "mp3" + mock_get_filename_from_url.assert_called_once_with(mock_file_name_or_url) + +@pytest.mark.parametrize("mock_file_name_or_url", [("/media/test_audio_file.mp3")]) +def test_get_file_extension_when_file_name_or_url_does_not_start_with_http_or_https(mock_file_name_or_url): + + result = _get_file_extension(mock_file_name_or_url) + + assert result == "mp3" + +@pytest.mark.parametrize("mock_source_url_or_file", [("http://test.com/media/test_audio_file.mp3")]) +def test_convert_to_wav_with_url(mock_source_url_or_file): + + with patch('src.audio_converter._get_file_extension') as mock_get_file_extension, \ + patch('src.audio_converter._is_url') as mock_is_url, \ + patch('src.audio_converter.httpx.get') as mock_httpx_get, \ + patch('pydub.AudioSegment.from_file') as mock_audio_segment: + + mock_get_file_extension.return_value = 'mp3' + mock_is_url.return_value = True + + mock_response = MagicMock() + mock_response.content = b"test audio content" + mock_httpx_get.return_value = mock_response + + mock_audio = MagicMock() + mock_audio.set_frame_rate.return_value = mock_audio + mock_audio.set_channels.return_value = mock_audio + mock_audio.export.return_value = None + mock_audio_segment.return_value = mock_audio + + result_wav_data = convert_to_wav(mock_source_url_or_file) + + assert isinstance(result_wav_data, bytes) + + mock_get_file_extension.assert_called_once_with(mock_source_url_or_file) + mock_is_url.assert_called_once_with(mock_source_url_or_file) + mock_httpx_get.assert_called_once_with(mock_source_url_or_file) + mock_audio_segment.assert_called_once() + +@pytest.mark.parametrize("mock_source_url_or_file", [("/media/test_audio_file.mp3")]) +def test_convert_to_wav_with_local_file(mock_source_url_or_file): + + with patch('src.audio_converter._get_file_extension') as mock_get_file_extension, \ + patch('src.audio_converter._is_url') as mock_is_url, patch('pydub.AudioSegment.from_file') as mock_audio_segment: + + mock_get_file_extension.return_value = 'mp3' + mock_is_url.return_value = False + + mock_audio = MagicMock() + mock_audio.set_frame_rate.return_value = mock_audio + mock_audio.set_channels.return_value = mock_audio + mock_audio.export.return_value = None + mock_audio_segment.return_value = mock_audio + + result_wav_data = convert_to_wav(mock_source_url_or_file) + + assert isinstance(result_wav_data, bytes) + + mock_get_file_extension.assert_called_once_with(mock_source_url_or_file) + mock_is_url.assert_called_once_with(mock_source_url_or_file) + mock_audio_segment.assert_called_once_with(mock_source_url_or_file, format="mp3") + +@pytest.mark.parametrize("mock_wav_bytes", [(b"some wav content")]) +def test_convert_wav_bytes_to_mp3_bytes(mock_wav_bytes): + + mock_wav_file = BytesIO(mock_wav_bytes) + + with patch('src.audio_converter.BytesIO', return_value = mock_wav_file) as mock_bytes_io, \ + patch('pydub.AudioSegment.from_file') as mock_audio_segment: + + mock_audio = MagicMock() + mock_audio.set_frame_rate.return_value = mock_audio + mock_audio.export.return_value = None + mock_audio_segment.return_value = mock_audio + + result = convert_wav_bytes_to_mp3_bytes(mock_wav_bytes) + + assert isinstance(result, bytes) + assert mock_bytes_io.call_count == 2 + + mock_bytes_io.assert_any_call(mock_wav_bytes) + mock_bytes_io.assert_called_with() + mock_audio_segment.assert_called_once_with(mock_wav_file, format="wav") + +@pytest.mark.asyncio +@pytest.mark.parametrize("mock_source_url_or_file", [("http://test.com/media/test_audio_file.mp3")]) +async def test_convert_to_wav_with_ffmpeg_for_valid_url(mock_source_url_or_file): + + with patch('src.audio_converter._get_file_extension', return_value = "mp3") as mock_get_file_extension, \ + patch('src.audio_converter._is_url', return_value = True) as mock_is_url, \ + patch('src.audio_converter.httpx.get') as mock_httpx_get, \ + patch('src.audio_converter.subprocess.run') as mock_subprocess_run, \ + patch('src.audio_converter.aiofiles.open') as mock_aiofiles_open, \ + patch('src.audio_converter.aiofiles.os.remove', return_value = None) as mock_aiofiles_os_remove: + + mock_response = MagicMock() + mock_response.content = b"test audio content" + mock_httpx_get.return_value = mock_response + + mock_subprocess_run.return_value = None + + mock_wav_file = AsyncMock() + mock_wav_file.read.return_value = b"test audio data" + mock_aiofiles_open.return_value.__aenter__.return_value = mock_wav_file + + result_audio_data = await convert_to_wav_with_ffmpeg(mock_source_url_or_file) + + assert isinstance(result_audio_data, bytes) + + mock_get_file_extension.assert_called_once_with(mock_source_url_or_file) + mock_is_url.assert_called_once_with(mock_source_url_or_file) + mock_httpx_get.assert_called_once_with(mock_source_url_or_file) + mock_subprocess_run.assert_called_once() + mock_aiofiles_open.assert_called_once() + mock_aiofiles_os_remove.assert_awaited_once() + +@pytest.mark.asyncio +@pytest.mark.parametrize("mock_source_url_or_file", [("/media/test_audio_file.mp3")]) +async def test_convert_to_wav_with_ffmpeg_for_local_file(mock_source_url_or_file): + + with patch('src.audio_converter._get_file_extension', return_value = "mp3") as mock_get_file_extension, \ + patch('src.audio_converter._is_url', return_value = False) as mock_is_url, \ + patch('src.audio_converter.subprocess.run') as mock_subprocess_run, \ + patch('src.audio_converter.aiofiles.open') as mock_aiofiles_open, \ + patch('src.audio_converter.aiofiles.os.remove', return_value = None) as mock_aiofiles_os_remove: + + mock_subprocess_run.return_value = None + + mock_wav_file = AsyncMock() + mock_wav_file.read.return_value = b"test audio data" + mock_aiofiles_open.return_value.__aenter__.return_value = mock_wav_file + + result_audio_data = await convert_to_wav_with_ffmpeg(mock_source_url_or_file) + + assert isinstance(result_audio_data, bytes) + + mock_get_file_extension.assert_called_once_with(mock_source_url_or_file) + mock_is_url.assert_called_once_with(mock_source_url_or_file) + mock_subprocess_run.assert_called_once() + mock_aiofiles_open.assert_called_once() + mock_aiofiles_os_remove.assert_awaited_once() diff --git a/language/tests/test_crud.py b/language/tests/test_crud.py new file mode 100644 index 00000000..e7b2dd71 --- /dev/null +++ b/language/tests/test_crud.py @@ -0,0 +1,89 @@ +from unittest import mock +from unittest.mock import AsyncMock, patch, MagicMock +import pytest + +from src.crud import get_user_preferred_language, get_user_preferred_language_by_pid +from lib.db_session_handler import DBSessionHandler + +class AsyncContextManagerMock: + def __init__(self, session_mock): + self.session_mock = session_mock + + async def __aenter__(self): + return self.session_mock + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + +class AsyncBeginMock: + async def __aenter__(self): + pass + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + +@pytest.mark.asyncio +async def test_get_user_preferred_language_when_pid_not_none(): + + turn_id = "test_turn_id" + pid = "test_user_id" + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + + mock_execute_result = mock.Mock() + mock_execute_result.scalars.return_value.first.return_value = pid + + mock_session.execute = mock.AsyncMock(return_value=mock_execute_result) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + with patch("src.crud.get_user_preferred_language_by_pid", return_value = "en") as mock_get_user_preferred_language_by_pid: + + result = await get_user_preferred_language(turn_id) + + assert result == "en" + + mock_get_user_preferred_language_by_pid.assert_awaited_once_with(pid) + +@pytest.mark.asyncio +async def test_get_user_preferred_language_when_pid_is_none(): + + turn_id = "test_turn_id" + pid = None + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + + mock_execute_result = mock.Mock() + mock_execute_result.scalars.return_value.first.return_value = pid + + mock_session.execute = mock.AsyncMock(return_value=mock_execute_result) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + with patch("src.crud.get_user_preferred_language_by_pid") as mock_get_user_preferred_language_by_pid: + + result = await get_user_preferred_language(turn_id) + + assert result == None + + mock_get_user_preferred_language_by_pid.assert_not_awaited() + +@pytest.mark.asyncio +async def test_get_user_preferred_language_by_pid(): + + pid = "test_user_id" + language_preference = "en" + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + + mock_execute_result = mock.Mock() + mock_execute_result.scalars.return_value.first.return_value = language_preference + + mock_session.execute = mock.AsyncMock(return_value=mock_execute_result) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + + result = await get_user_preferred_language_by_pid(pid) + + assert result == "en" \ No newline at end of file diff --git a/language/tests/test_incoming.py b/language/tests/test_incoming.py index 0f44d5dc..096d911c 100644 --- a/language/tests/test_incoming.py +++ b/language/tests/test_incoming.py @@ -109,3 +109,35 @@ async def test_handle_input_audio_message(): assert result.user_input.message.message_type == MessageType.TEXT assert result.user_input.message.text is not None assert result.user_input.message.text.body == "Translated audio message" + +@pytest.mark.asyncio +async def test_handle_input_when_message_type_is_text_but_no_text_content_provided(): + mock_extension.reset_mock() + + preferred_language = LanguageCodes.EN + turn_id = "turn1" + + message = MagicMock() + message.message_type.return_value = MessageType.TEXT + message.text.return_value = None + + with pytest.raises(ValueError) as exception_info: + await handle_input(turn_id, preferred_language, message) + + assert exception_info.value == "Message text is empty" + +@pytest.mark.asyncio +async def test_handle_input_when_message_type_is_audio_but_no_audio_content_provided(): + mock_extension.reset_mock() + + preferred_language = LanguageCodes.EN + turn_id = "turn1" + + message = MagicMock() + message.message_type.return_value = MessageType.AUDIO + message.audio.return_value = None + + with pytest.raises(ValueError) as exception_info: + await handle_input(turn_id, preferred_language, message) + + assert exception_info.value == "Message audio is empty" \ No newline at end of file diff --git a/language/tests/test_outgoing.py b/language/tests/test_outgoing.py index 80970ed7..c81313dd 100644 --- a/language/tests/test_outgoing.py +++ b/language/tests/test_outgoing.py @@ -164,3 +164,32 @@ async def test_handle_output_button_message(): assert result[1].bot_output.message_type == MessageType.AUDIO assert result[1].bot_output.audio is not None assert result[1].bot_output.audio.media_url == "https://storage.url/test_audio.ogg" + + +@pytest.mark.asyncio +async def test_handle_output_text_message_with_header_and_footer(): + mock_extension.reset_mock() + turn_id = "turn1" + preferred_language = LanguageCodes.EN + message = Message( + message_type=MessageType.TEXT, + text=TextMessage(header="test_header",body="hello",footer="test_footer"), + ) + result = await handle_output(turn_id, preferred_language, message) + + assert isinstance(result, list) + assert len(result) == 2 + assert isinstance(result[0], Channel) + assert result[0].turn_id == turn_id + assert result[0].bot_output is not None + assert result[0].bot_output.message_type == MessageType.TEXT + assert result[0].bot_output.text is not None + assert result[0].bot_output.text.header == "translated_test_header" + assert result[0].bot_output.text.body == "translated_hello" + assert result[0].bot_output.text.footer == "translated_test_footer" + assert isinstance(result[1], Channel) + assert result[1].turn_id == turn_id + assert result[1].bot_output is not None + assert result[1].bot_output.message_type == MessageType.AUDIO + assert result[1].bot_output.audio is not None + assert result[1].bot_output.audio.media_url == "https://storage.url/test_audio.ogg" \ No newline at end of file From 9e141cc3d856d2809bcf8f259c3753a4cf076146 Mon Sep 17 00:00:00 2001 From: Ananya Agrawal Date: Mon, 16 Dec 2024 15:30:45 +0530 Subject: [PATCH 09/11] Added test cases in Indexer to improve test coverage --- indexer/tests/test_indexer.py | 141 +++++++++++++++++++++++++++++++++- indexer/tests/test_model.py | 20 +++++ 2 files changed, 160 insertions(+), 1 deletion(-) create mode 100644 indexer/tests/test_model.py diff --git a/indexer/tests/test_indexer.py b/indexer/tests/test_indexer.py index 665ce36b..08a9687a 100644 --- a/indexer/tests/test_indexer.py +++ b/indexer/tests/test_indexer.py @@ -2,9 +2,11 @@ import os from contextlib import asynccontextmanager from unittest.mock import ANY, AsyncMock, MagicMock, mock_open, patch - +from r2r import ChunkingConfig, R2RBuilder, R2RConfig +import asyncpg import pytest from lib.data_models import Indexer +from model import InternalServerException # Mock environment variables @@ -38,6 +40,39 @@ async def mock_read_file(file_path, mode): ): import indexing + def test_parse_file(): + # Mock parser functions + with patch("indexing.pdf_parser", return_value="PDF Content") as mock_pdf_parser, \ + patch("indexing.docx_parser", return_value="DOCX Content") as mock_docx_parser, \ + patch("indexing.xlsx_parser", return_value="XLSX Content") as mock_xlsx_parser, \ + patch("indexing.json_parser", return_value="JSON Content") as mock_json_parser, \ + patch("indexing.default_parser", return_value="Default Content") as mock_default_parser: + + # Test PDF file + result = indexing.parse_file("test.pdf") + assert result == "PDF Content" + mock_pdf_parser.assert_called_once_with("test.pdf") + + # Test DOCX file + result = indexing.parse_file("test.docx") + assert result == "DOCX Content" + mock_docx_parser.assert_called_once_with("test.docx") + + # Test XLSX file + result = indexing.parse_file("test.xlsx") + assert result == "XLSX Content" + mock_xlsx_parser.assert_called_once_with("test.xlsx") + + # Test JSON file + result = indexing.parse_file("test.json") + assert result == "JSON Content" + mock_json_parser.assert_called_once_with("test.json") + + # Test unsupported file extension (uses default parser) + result = indexing.parse_file("test.txt") + assert result == "Default Content" + mock_default_parser.assert_called_once_with("test.txt") + def test_docx_parser(): with patch("docx2txt.process", return_value="DOCX Content") as mock_process: result = indexing.docx_parser("test.docx") @@ -80,6 +115,28 @@ async def test_text_converter_textify(): result = await converter.textify("test.txt") assert result == "Parsed Content" + @pytest.mark.asyncio + async def test_create_pg_vector_index_if_not_exists(): + # Mock the asyncpg connection and its methods + mock_connection = MagicMock(spec=asyncpg.connection.Connection) + + with patch("asyncpg.connect", return_value=mock_connection): + # Initialize the DataIndexer instance + indexer = indexing.DataIndexer() + + # Call the method + await indexer.create_pg_vector_index_if_not_exists() + + # Assertions + mock_connection.transaction.assert_called_once() + mock_connection.execute.assert_any_call( + "ALTER TABLE langchain_pg_embedding ALTER COLUMN embedding TYPE vector(1536)" + ) + mock_connection.execute.assert_any_call( + "CREATE INDEX IF NOT EXISTS langchain_embeddings_hnsw ON langchain_pg_embedding USING hnsw (embedding vector_cosine_ops)" + ) + mock_connection.close.assert_called_once() + # Test DataIndexer.index method for default input @pytest.mark.asyncio async def test_default_data_indexer(): @@ -105,6 +162,33 @@ async def test_default_data_indexer(): await indexer.index(indexer_input) mock_pg_vector.assert_called_once() indexer.create_pg_vector_index_if_not_exists.assert_awaited_once() + + @pytest.mark.asyncio + async def test_get_r2r(): + indexer = indexing.DataIndexer() + chunk_size = 4000 + chunk_overlap = 200 + mock_r2r_app = MagicMock() + + with patch("indexing.R2RConfig") as MockR2RConfig, patch("indexing.R2RBuilder") as MockR2RBuilder: + mock_r2r_builder = MagicMock() + MockR2RBuilder.return_value = mock_r2r_builder + mock_r2r_builder.build.return_value = mock_r2r_app + + r2r_app = await indexer.get_r2r(chunk_size=chunk_size, chunk_overlap=chunk_overlap) + + MockR2RConfig.assert_called_once_with( + config_data={ + "chunking": ChunkingConfig( + chunk_size=chunk_size, chunk_overlap=chunk_overlap + ), + } + ) + + MockR2RBuilder.assert_called_once_with(config=MockR2RConfig.return_value) + mock_r2r_builder.build.assert_called_once() + + assert r2r_app == mock_r2r_app # Test DataIndexer.index method for r2r input @pytest.mark.asyncio @@ -124,3 +208,58 @@ async def test_r2r_data_indexer(): await indexer.index(indexer_input) mock_r2r_app.engine.aingest_files.assert_awaited_once_with(files=[ANY]) assert os.environ["POSTGRES_VECS_COLLECTION"] == "test_collection" + + @pytest.mark.asyncio + async def test_get_embeddings_azure(): + # Mock environment variables for Azure setup + os.environ["OPENAI_API_TYPE"] = "azure" + os.environ["AZURE_EMBEDDING_MODEL_NAME"] = "azure-model-name" + os.environ["AZURE_DEPLOYMENT_NAME"] = "azure-deployment" + os.environ["AZURE_OPENAI_ENDPOINT"] = "https://azure-endpoint" + os.environ["AZURE_OPENAI_API_KEY"] = "azure-api-key" + + indexer = indexing.DataIndexer() + + # Mock the AzureOpenAIEmbeddings class and its constructor + with patch("indexing.AzureOpenAIEmbeddings") as mock_azure_embeddings: + mock_instance = MagicMock() + mock_azure_embeddings.return_value = mock_instance + + # Call the get_embeddings method + embeddings = await indexer.get_embeddings() + + # Assert that AzureOpenAIEmbeddings was created with the correct parameters + mock_azure_embeddings.assert_called_once_with( + model="azure-model-name", + dimensions=1536, + azure_deployment="azure-deployment", + azure_endpoint="https://azure-endpoint", + openai_api_type="azure", + openai_api_key="azure-api-key" + ) + + # Assert the returned value is the mocked instance of AzureOpenAIEmbeddings + assert embeddings == mock_instance + + + @pytest.mark.asyncio + async def test_get_embeddings_openai(): + # Mock environment variables for OpenAI setup + os.environ["OPENAI_API_TYPE"] = "openai" + + indexer = indexing.DataIndexer() + + # Mock the OpenAIEmbeddings class and its constructor + with patch("indexing.OpenAIEmbeddings") as mock_openai_embeddings: + mock_instance = MagicMock() + mock_openai_embeddings.return_value = mock_instance + + # Call the get_embeddings method + embeddings = await indexer.get_embeddings() + + # Assert that OpenAIEmbeddings was created with the correct parameters + mock_openai_embeddings.assert_called_once_with(client="") + + # Assert the returned value is the mocked instance of OpenAIEmbeddings + assert embeddings == mock_instance + diff --git a/indexer/tests/test_model.py b/indexer/tests/test_model.py new file mode 100644 index 00000000..f2c1c4a6 --- /dev/null +++ b/indexer/tests/test_model.py @@ -0,0 +1,20 @@ +from model import InternalServerException, ServiceUnavailableException + +# Test for InternalServerException +def test_internal_server_exception(): + message = "An internal server error occurred" + exception = InternalServerException(message) + + assert exception.message == message + assert exception.status_code == 500 + assert str(exception) == message + + +# Test for ServiceUnavailableException +def test_service_unavailable_exception(): + message = "Service is unavailable" + exception = ServiceUnavailableException(message) + + assert exception.message == message + assert exception.status_code == 503 + assert str(exception) == message From 441160f79d9b55f6f875a35407d6a90eca5b4142 Mon Sep 17 00:00:00 2001 From: Sunandhita B Date: Wed, 18 Dec 2024 17:42:26 +0530 Subject: [PATCH 10/11] Added flow tests to improve test coverage --- flow/tests/test_bot_input.py | 445 ++++++++++++++++++++++++++++++++- flow/tests/test_bot_install.py | 99 ++++++++ flow/tests/test_crud.py | 426 +++++++++++++++++++++++++++++++ 3 files changed, 969 insertions(+), 1 deletion(-) create mode 100644 flow/tests/test_crud.py diff --git a/flow/tests/test_bot_input.py b/flow/tests/test_bot_input.py index 35c4e054..db663002 100644 --- a/flow/tests/test_bot_input.py +++ b/flow/tests/test_bot_input.py @@ -2,7 +2,8 @@ from unittest.mock import MagicMock, patch, AsyncMock import json import pytest -from lib.models import JBBot +from datetime import datetime, timedelta +from lib.models import JBBot, JBFSMState, JBSession from lib.data_models import ( Flow, FlowIntent, @@ -13,6 +14,7 @@ Language, LanguageIntent, FSMOutput, + FSMInput, FSMIntent, MessageType, Message, @@ -34,6 +36,8 @@ RAGQuery, ) +mock_extension = MagicMock() + mock_bot = JBBot( id="test_bot_id", name="test_bot_name", @@ -435,3 +439,442 @@ async def test_handle_flow_input( "zerotwo", {"state": "test_state", "variables": {"test_key": "test_value"}}, ) + +@pytest.mark.asyncio +@patch.dict("sys.modules", {"src.extensions": mock_extension}) +async def test_handle_flow_input_when_flow_intent_is_dialog_and_dialog_not_found(): + + mock_flow_input = MagicMock(spec = Flow) + mock_flow_input.intent = FlowIntent.DIALOG + mock_flow_input.dialog = None + + from src.handlers.flow_input import handle_flow_input + + with patch("src.handlers.flow_input.logger") as mock_logger: + result = await handle_flow_input(mock_flow_input) + + assert result is None + + mock_logger.error.asset_called_once_with("Dialog not found in flow input") + +@pytest.mark.asyncio +@patch.dict("sys.modules", {"src.extensions": mock_extension}) +async def test_handle_flow_input_when_flow_intent_is_callback_and_callback_not_found(): + + mock_flow_input = MagicMock(spec = Flow) + mock_flow_input.intent = FlowIntent.CALLBACK + mock_flow_input.callback = None + + from src.handlers.flow_input import handle_flow_input + + with patch("src.handlers.flow_input.logger") as mock_logger: + result = await handle_flow_input(mock_flow_input) + + assert result is None + + mock_logger.error.asset_called_once_with("Callback not found in flow input") + +@pytest.mark.asyncio +@patch.dict("sys.modules", {"src.extensions": mock_extension}) +async def test_handle_flow_input_when_flow_intent_is_user_input_and_user_input_not_found(): + + mock_flow_input = MagicMock(spec = Flow) + mock_flow_input.intent = FlowIntent.USER_INPUT + mock_flow_input.user_input = None + + from src.handlers.flow_input import handle_flow_input + + with patch("src.handlers.flow_input.logger") as mock_logger: + result = await handle_flow_input(mock_flow_input) + + assert result is None + + mock_logger.error.asset_called_once_with("User input not found in flow input") + +@pytest.mark.asyncio +@patch.dict("sys.modules", {"src.extensions": mock_extension}) +async def test_handle_flow_input_for_invalid_flow_intent(): + + mock_flow_input = MagicMock(spec = Flow) + mock_flow_input.intent = "test_invalid_flow_intent" + + from src.handlers.flow_input import handle_flow_input + + with patch("src.handlers.flow_input.logger") as mock_logger: + result = await handle_flow_input(mock_flow_input) + + assert result is None + + mock_logger.error.asset_called_once_with("Invalid flow intent: test_invalid_flow_intent") + +@patch.dict( + "sys.modules", {"src.extensions": MagicMock(produce_message=mock_produce_message)} +) +def test_handle_bot_output_when_message_not_found_in_fsm_output_for_send_message_fsm_intent(): + + mock_produce_message.reset_mock() + + turn_id = "test_turn_id" + mock_fsm_output = MagicMock(spec = FSMOutput) + mock_fsm_output.intent = FSMIntent.SEND_MESSAGE + mock_fsm_output.message = None + + with patch("src.handlers.bot_input.logger") as mock_logger: + from src.handlers.bot_input import handle_bot_output + + result = handle_bot_output(mock_fsm_output, turn_id) + + assert result is None + + mock_logger.error.asset_called_once_with("Message not found in fsm output") + +@patch.dict( + "sys.modules", {"src.extensions": MagicMock(produce_message=mock_produce_message)} +) +def test_handle_bot_output_when_rag_query_not_found_in_fsm_output_for_fsm_intent_rag_call(): + + mock_produce_message.reset_mock() + + turn_id = "test_turn_id" + mock_fsm_output = MagicMock(spec = FSMOutput) + mock_fsm_output.intent = FSMIntent.RAG_CALL + mock_fsm_output.rag_query = None + + with patch("src.handlers.bot_input.logger") as mock_logger: + from src.handlers.bot_input import handle_bot_output + + result = handle_bot_output(mock_fsm_output, turn_id) + + assert result is None + + mock_logger.error.asset_called_once_with("RAG query not found in fsm output") + +@patch.dict( + "sys.modules", {"src.extensions": MagicMock(produce_message=mock_produce_message)} +) +def test_handle_bot_output_for_invalid_intent_in_fsm_output(): + + mock_produce_message.reset_mock() + + turn_id = "test_turn_id" + mock_fsm_output = MagicMock(spec = FSMOutput) + mock_fsm_output.intent = "test_invalid_intent" + + with patch("src.handlers.bot_input.logger") as mock_logger: + from src.handlers.bot_input import handle_bot_output + + result = handle_bot_output(mock_fsm_output, turn_id) + + assert result is NotImplemented + + mock_logger.error.asset_called_once_with("Invalid intent in fsm output") + +@pytest.mark.asyncio +@patch.dict( + "sys.modules", {"src.extensions": MagicMock(produce_message=mock_produce_message)} +) +async def test_manage_session_for_new_session(): + + mock_produce_message.reset_mock() + + turn_id = "test_turn_id" + new_session = True + + mock_session = JBSession(id="test_session_id", user_id="test_user_id1", channel_id="test_channel_id1") + + with patch("src.handlers.bot_input.create_session", return_value = mock_session) as mock_create_session: + from src.handlers.bot_input import manage_session + + result_session = await manage_session(turn_id, new_session) + + assert result_session is not None + assert isinstance(result_session, JBSession) + assert result_session.id == mock_session.id + assert result_session.user_id == mock_session.user_id + assert result_session.channel_id == mock_session.channel_id + + mock_create_session.assert_awaited_once_with(turn_id) + +@pytest.mark.asyncio +@patch.dict( + "sys.modules", {"src.extensions": MagicMock(produce_message=mock_produce_message)} +) +async def test_manage_session_when_session_not_found(): + + mock_produce_message.reset_mock() + + turn_id = "test_turn_id" + new_session = False + + mock_session = JBSession(id="test_session_id", + user_id="test_user_id1", + channel_id="test_channel_id1") + + with patch("src.handlers.bot_input.get_session_by_turn_id", return_value = None) as mock_get_session_by_turn_id, \ + patch("src.handlers.bot_input.create_session", return_value = mock_session) as mock_create_session: + from src.handlers.bot_input import manage_session + + result_session = await manage_session(turn_id, new_session) + + assert result_session is not None + assert isinstance(result_session, JBSession) + assert result_session.id == mock_session.id + assert result_session.user_id == mock_session.user_id + assert result_session.channel_id == mock_session.channel_id + + mock_get_session_by_turn_id.assert_awaited_once_with(turn_id) + mock_create_session.assert_awaited_once_with(turn_id) + +@pytest.mark.asyncio +@patch.dict( + "sys.modules", {"src.extensions": MagicMock(produce_message=mock_produce_message)} +) +async def test_manage_session_when_session_updating(): + + mock_produce_message.reset_mock() + + turn_id = "test_turn_id" + new_session = False + + mock_session = JBSession(id="test_session_id", + user_id="test_user_id", + channel_id="test_channel_id", + updated_at = datetime.now() - timedelta(hours=1)) + + with patch("src.handlers.bot_input.get_session_by_turn_id", return_value = mock_session) as mock_get_session_by_turn_id, \ + patch("src.handlers.bot_input.update_session", return_value = None) as mock_update_session, \ + patch("src.handlers.bot_input.update_turn", return_value = None) as mock_update_turn: + + from src.handlers.bot_input import manage_session + + result_session = await manage_session(turn_id, new_session) + + assert result_session is not None + assert isinstance(result_session, JBSession) + assert result_session.id == mock_session.id + assert result_session.user_id == mock_session.user_id + assert result_session.channel_id == mock_session.channel_id + + mock_get_session_by_turn_id.assert_awaited_once_with(turn_id) + mock_update_session.assert_awaited_once_with(mock_session.id) + mock_update_turn.assert_awaited_once_with(session_id=mock_session.id, turn_id=turn_id) + +@pytest.mark.asyncio +@patch.dict( + "sys.modules", {"src.extensions": MagicMock(produce_message=mock_produce_message)} +) +async def test_handle_bot_input_when_bot_not_found_for_given_session_id(): + + mock_produce_message.reset_mock() + + mock_state = MagicMock() + + mock_fsm_input = MagicMock(spec= FSMInput) + mock_session_id = "test_session_id" + + with patch("src.handlers.bot_input.get_state_by_session_id", return_value = mock_state) as mock_get_state_by_session_id, \ + patch("src.handlers.bot_input.get_bot_by_session_id", return_value = None) as mock_get_bot_by_session_id, \ + patch("src.handlers.bot_input.logger") as mock_logger: + + from src.handlers.bot_input import handle_bot_input + + result = [item async for item in handle_bot_input(mock_fsm_input, mock_session_id)] + + assert result==[] + + mock_get_state_by_session_id.assert_awaited_once_with(mock_session_id) + mock_get_bot_by_session_id.assert_awaited_once_with(mock_session_id) + mock_logger.error.asset_called_once_with("Bot not found for session_id: test_session_id") + +@pytest.mark.asyncio +@patch.dict( + "sys.modules", {"src.extensions": MagicMock(produce_message=mock_produce_message)} +) +async def test_handle_bot_input_when_error_in_running_fsm(): + + mock_produce_message.reset_mock() + + mock_fsm_input = MagicMock(spec= FSMInput) + mock_fsm_input.model_dump.return_value = mock_fsm_input + + mock_session_id = "test_session_id" + + mock_state = MagicMock(spec= JBFSMState) + mock_state.variables={"test_key": "test_value"} + + mock_completed_process = MagicMock() + mock_completed_process.stderr = "test_error_message" + + with patch("src.handlers.bot_input.get_state_by_session_id", return_value = mock_state) as mock_get_state_by_session_id, \ + patch("src.handlers.bot_input.get_bot_by_session_id", return_value = mock_bot) as mock_get_bot_by_session_id, \ + patch("subprocess.run", return_value = mock_completed_process), \ + patch("json.dumps", return_value = MagicMock()), \ + patch("src.handlers.bot_input.logger") as mock_logger: + + from src.handlers.bot_input import handle_bot_input + + result = [item async for item in handle_bot_input(mock_fsm_input, mock_session_id)] + + assert result==[] + + mock_get_state_by_session_id.assert_awaited_once_with(mock_session_id) + mock_get_bot_by_session_id.assert_awaited_once_with(mock_session_id) + mock_logger.error.asset_called_once_with("Error while running fsm: test_error_message") + +@pytest.mark.asyncio +@patch.dict( + "sys.modules", {"src.extensions": MagicMock(produce_message=mock_produce_message)} +) +async def test_handle_user_input_when_text_not_found(): + + mock_produce_message.reset_mock() + + mock_user_input = MagicMock(spec = UserInput) + mock_user_input.turn_id = "test_turn_id" + + mock_user_input.message = MagicMock(spec = Message) + mock_user_input.message.message_type = MessageType.TEXT + mock_user_input.message.text = None + + with patch("src.handlers.bot_input.logger") as mock_logger: + from src.handlers.bot_input import handle_user_input + + result = await handle_user_input(mock_user_input) + + assert result is None + + mock_logger.error.asset_called_once_with("Text not found in user input") + +@pytest.mark.asyncio +@patch.dict( + "sys.modules", {"src.extensions": MagicMock(produce_message=mock_produce_message)} +) +async def test_handle_user_input_when_interactive_reply_not_found(): + + mock_produce_message.reset_mock() + + mock_user_input = MagicMock(spec = UserInput) + mock_user_input.turn_id = "test_turn_id" + + mock_user_input.message = MagicMock(spec = Message) + mock_user_input.message.message_type = MessageType.INTERACTIVE_REPLY + mock_user_input.message.interactive_reply = None + + with patch("src.handlers.bot_input.logger") as mock_logger: + from src.handlers.bot_input import handle_user_input + + result = await handle_user_input(mock_user_input) + + assert result is None + + mock_logger.error.asset_called_once_with("Interactive reply not found in user input") + +@pytest.mark.asyncio +@patch.dict( + "sys.modules", {"src.extensions": MagicMock(produce_message=mock_produce_message)} +) +async def test_handle_user_input_when_form_reply_not_found(): + + mock_produce_message.reset_mock() + + mock_user_input = MagicMock(spec = UserInput) + mock_user_input.turn_id = "test_turn_id" + + mock_user_input.message = MagicMock(spec = Message) + mock_user_input.message.message_type = MessageType.FORM_REPLY + mock_user_input.message.form_reply = None + + with patch("src.handlers.bot_input.logger") as mock_logger: + from src.handlers.bot_input import handle_user_input + + result = await handle_user_input(mock_user_input) + + assert result is None + + mock_logger.error.asset_called_once_with("Form reply not found in user input") + +@pytest.mark.asyncio +@patch.dict( + "sys.modules", {"src.extensions": MagicMock(produce_message=mock_produce_message)} +) +async def test_handle_user_input_when_message_type_not_implemented_or_invalid(): + + mock_produce_message.reset_mock() + + mock_user_input = MagicMock(spec = UserInput) + mock_user_input.turn_id = "test_turn_id" + + mock_user_input.message = MagicMock(spec = Message) + mock_user_input.message.message_type = "test_not_implemented_message_type" + + from src.handlers.bot_input import handle_user_input + + result = await handle_user_input(mock_user_input) + + assert result is NotImplemented + +@pytest.mark.asyncio +@patch.dict( + "sys.modules", {"src.extensions": MagicMock(produce_message=mock_produce_message)} +) +async def test_handle_callback_input_when_rag_response_not_found(): + + mock_produce_message.reset_mock() + + mock_callback = MagicMock(spec = Callback) + mock_callback.turn_id = "test_turn_id" + mock_callback.callback_type = CallbackType.RAG + mock_callback.rag_response = None + + with patch("src.handlers.bot_input.logger") as mock_logger: + from src.handlers.bot_input import handle_callback_input + + result = await handle_callback_input(mock_callback) + + assert result is None + + mock_logger.error.asset_called_once_with("RAG response not found in callback input") + +@pytest.mark.asyncio +@patch.dict( + "sys.modules", {"src.extensions": MagicMock(produce_message=mock_produce_message)} +) +async def test_handle_dialog_input_when_message_not_found(): + + mock_produce_message.reset_mock() + + mock_dialog = MagicMock(spec = Dialog) + mock_dialog.turn_id = "test_turn_id" + + mock_dialog.message = MagicMock(spec = Message) + mock_dialog.message.dialog = None + + with patch("src.handlers.bot_input.logger") as mock_logger: + from src.handlers.bot_input import handle_dialog_input + + result = await handle_dialog_input(mock_dialog) + + assert result is None + + mock_logger.error.asset_called_once_with("Message not found in dialog input") + +@pytest.mark.asyncio +@patch.dict( + "sys.modules", {"src.extensions": MagicMock(produce_message=mock_produce_message)} +) +async def test_handle_dialog_input_when_dialog_option_not_implemented(): + + mock_produce_message.reset_mock() + + mock_dialog = MagicMock(spec = Dialog) + mock_dialog.turn_id = "test_turn_id" + mock_dialog.message = MagicMock(spec = Message) + + mock_dialog.message.dialog = MagicMock(spec = DialogMessage) + mock_dialog.message.dialog.dialog_id = "test_not_implemented_dialog_option" + + from src.handlers.bot_input import handle_dialog_input + + result = await handle_dialog_input(mock_dialog) + + assert result is NotImplemented \ No newline at end of file diff --git a/flow/tests/test_bot_install.py b/flow/tests/test_bot_install.py index c4e7594f..de306d58 100644 --- a/flow/tests/test_bot_install.py +++ b/flow/tests/test_bot_install.py @@ -1,3 +1,6 @@ +import os +from pathlib import Path +import shutil from unittest.mock import MagicMock, patch, AsyncMock import pytest from lib.data_models import BotConfig, Flow, Bot, FlowIntent, BotIntent @@ -67,3 +70,99 @@ async def test_handle_bot_delete(mock_delete_bot): await handle_flow_input(flow_input) mock_delete_bot.assert_awaited_once_with("test_bot_id") + +@pytest.mark.asyncio +@patch.dict("sys.modules", {"src.extensions": mock_extension}) +async def test_handle_bot_install_with_bot_config_missing_bot(): + + mock_flow_input = MagicMock(spec = Flow) + + mock_flow_input.intent = FlowIntent.BOT + mock_flow_input.bot_config = MagicMock(spec = BotConfig) + mock_flow_input.bot_config.intent = BotIntent.INSTALL + mock_flow_input.bot_config.bot = None + + from src.handlers.flow_input import handle_flow_input + + with patch("src.handlers.flow_input.logger") as mock_logger: + result = await handle_flow_input(mock_flow_input) + + assert result is None + + mock_logger.error.asset_called_once_with("Bot config missing bot") + +@pytest.mark.asyncio +@patch.dict("sys.modules", {"src.extensions": mock_extension}) +async def test_handle_bot_with_invalid_bot_config_intent(): + + mock_flow_input = MagicMock(spec = Flow) + + mock_flow_input.intent = FlowIntent.BOT + mock_flow_input.bot_config = MagicMock(spec = BotConfig) + mock_flow_input.bot_config.intent = "invalid_bot_config_intent" + + from src.handlers.flow_input import handle_flow_input + + with patch("src.handlers.flow_input.logger") as mock_logger: + result = await handle_flow_input(mock_flow_input) + + assert result is None + + mock_logger.error.asset_called_once_with("Invalid intent in bot config") + +@pytest.mark.asyncio +async def test_install_bot(): + + bot_id = "test_bot_id" + bot_fsm_code = "test_bot_fsm_code" + bot_requirements_txt = "test_requirements" + index_urls = {"index_url_1":"test_url_1","index_url_2":"test_url_2"} + + mock_install_command = MagicMock(spec=list) + mock_install_command.extend.return_value = None + + from src.handlers.bot_install import install_bot + + with patch.object(Path, "mkdir", return_value = AsyncMock()), \ + patch.object(shutil, "copytree", return_value = None) as mock_copytree, \ + patch.object(shutil, "copy2", return_value = None) as mock_copy2, \ + patch.object(Path, "write_text", return_value = AsyncMock()),\ + patch('src.handlers.bot_install.subprocess.run', return_value = None) as mock_subprocess_run: + + result = await install_bot(bot_id, bot_fsm_code, bot_requirements_txt, index_urls) + + assert result is None + assert mock_subprocess_run.call_count == 2 + mock_copytree.assert_not_called() + mock_copy2.assert_called_once() + +@pytest.mark.asyncio +async def test_delete_bot(): + + bot_id = "test_bot_id" + + with patch.object(shutil, "rmtree", return_value = None) as mock_rmtree: + from src.handlers.bot_install import delete_bot + + await delete_bot(bot_id) + + mock_rmtree.assert_called_once() + + + +@pytest.mark.asyncio +@patch.dict("sys.modules", {"src.extensions": mock_extension}) +async def test_handle_flow_input_when_flow_intent_is_bot_and_bot_config_not_found(): + + mock_flow_input = MagicMock(spec = Flow) + mock_flow_input.intent = FlowIntent.BOT + mock_flow_input.bot_config = None + + from src.handlers.flow_input import handle_flow_input + + with patch("src.handlers.flow_input.logger") as mock_logger: + result = await handle_flow_input(mock_flow_input) + + assert result is None + + mock_logger.error.asset_called_once_with("Bot config not found in flow input") \ No newline at end of file diff --git a/flow/tests/test_crud.py b/flow/tests/test_crud.py new file mode 100644 index 00000000..16643e1c --- /dev/null +++ b/flow/tests/test_crud.py @@ -0,0 +1,426 @@ +import json +import pytest +from unittest import mock +from lib.db_session_handler import DBSessionHandler +from lib.models import JBBot, JBFSMState, JBSession, JBTurn +from lib.data_models import ( + MessageType, + Message, + TextMessage, +) + +from src.crud import ( + create_bot, + create_message, + create_session, + get_all_bots, + get_bot_by_session_id, + get_session_by_turn_id, + get_state_by_session_id, + insert_jb_webhook_reference, + insert_state, + update_session, + update_state_and_variables, + update_turn, + update_user_language +) + +class AsyncContextManagerMock: + def __init__(self, session_mock): + self.session_mock = session_mock + + async def __aenter__(self): + return self.session_mock + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + +class AsyncBeginMock: + async def __aenter__(self): + pass + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + +@pytest.mark.asyncio +async def test_get_state_by_session_id(): + + session_id = "test_session_id" + + mock_state = JBFSMState( + id = "test_id", + session_id = "test_session_id", + state = "test_state", + variables = {"var1":"test_variable_1","var2":"test_variable_2"}, + message = "test_message", + ) + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + + mock_execute_result = mock.Mock() + mock_execute_result.scalars.return_value.first.return_value = mock_state + + mock_session.execute = mock.AsyncMock(return_value=mock_execute_result) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + + result = await get_state_by_session_id(session_id) + + assert isinstance(result,JBFSMState) + assert result.id == mock_state.id + assert result.session_id == mock_state.session_id + assert result.state == mock_state.state + assert result.variables == mock_state.variables + assert result.message == mock_state.message + +@pytest.mark.asyncio +async def test_create_session(): + + turn_id = "test_turn_id" + session_id = "test_session_id" + + mock_jb_turn = JBTurn( + id = "test_turn_id", + session_id = "test_session_id", + bot_id = "test_bot_id", + channel_id = "test_channel_id", + user_id = "test_user_id", + turn_type = "test_turn_type" + ) + + mock_session_object = JBSession( + id = "test_session_id", + user_id = "test_user_id", + channel_id = "test_channel_id" + ) + + mock_session = mock.Mock() + mock_session.commit = mock.AsyncMock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + + mock_execute_result = mock.Mock() + mock_execute_result.scalars.return_value.first.return_value = mock_jb_turn + + mock_session.execute = mock.AsyncMock(return_value=mock_execute_result) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + with mock.patch("src.crud.uuid.uuid4", return_value = session_id): + result = await create_session(turn_id) + + assert result is not None + assert isinstance(result, JBSession) + assert result.id == mock_session_object.id + assert result.user_id == mock_session_object.user_id + assert result.channel_id == mock_session_object.channel_id + + assert mock_session.commit.call_count == 2 + assert mock_session.begin.call_count == 3 + +@pytest.mark.asyncio +async def test_update_session(): + + session_id = "test_session_id" + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + mock_session.commit = mock.AsyncMock() + + mock_execute = mock.AsyncMock() + mock_session.execute = mock_execute + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + + result = await update_session(session_id) + + assert result is None + + mock_session.execute.assert_awaited_once() + mock_session.commit.assert_awaited_once() + +@pytest.mark.asyncio +async def test_update_turn(): + + session_id = "test_session_id" + turn_id = "test_turn_id" + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + mock_session.commit = mock.AsyncMock() + + mock_execute = mock.AsyncMock() + mock_session.execute = mock_execute + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + + result = await update_turn(session_id, turn_id) + + assert result is None + + mock_session.execute.assert_awaited_once() + mock_session.commit.assert_awaited_once() + +@pytest.mark.asyncio +async def test_insert_state(): + + session_id = "test_session_id" + state = "test_state" + variables = {"var1":"test_variable_1","var2":"test_variable_2"} + + state_id = "test_state_id" + mock_state = JBFSMState( + id = "test_state_id", + session_id = "test_session_id", + state = "test_state", + variables = {"var1":"test_variable_1","var2":"test_variable_2"} + ) + + mock_session = mock.Mock() + mock_session.commit = mock.AsyncMock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + with mock.patch("src.crud.uuid.uuid4", return_value = state_id): + result = await insert_state(session_id, state, variables) + + assert isinstance(result,JBFSMState) + assert result.id == mock_state.id + assert result.session_id == mock_state.session_id + assert result.state == mock_state.state + assert result.variables == mock_state.variables + + mock_session.commit.assert_awaited_once() + +@pytest.mark.asyncio +async def test_update_state_and_variables(): + + session_id = "test_session_id" + state = "test_state" + variables = {"var1":"test_variable_1","var2":"test_variable_2"} + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + mock_session.commit = mock.AsyncMock() + + mock_execute = mock.AsyncMock() + mock_session.execute = mock_execute + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + + result = await update_state_and_variables(session_id, state, variables) + + assert result is not None + assert result == state + + mock_session.execute.assert_awaited_once() + mock_session.commit.assert_awaited_once() + +@pytest.mark.asyncio +async def test_get_bot_by_session_id(): + + session_id = "test_session_id" + + mock_bot = JBBot( + id = "test_bot_id", + name = "Bot", + status = "active", + dsl = "test_dsl", + code = "test_code", + requirements = "test_requirements", + index_urls = {"index_url1":"test_url_1","index_url2":"test_url_2"}, + config_env = {}, + required_credentials = {"API_KEY":"test_api_key"}, + credentials = {}, + version = "0.0.1" + ) + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + + mock_execute_result = mock.Mock() + mock_execute_result.scalars.return_value.first.return_value = mock_bot + + mock_session.execute = mock.AsyncMock(return_value=mock_execute_result) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + + result = await get_bot_by_session_id(session_id) + + assert isinstance(result,JBBot) + assert result.id == mock_bot.id + assert result.name == mock_bot.name + assert result.status == mock_bot.status + assert result.dsl == mock_bot.dsl + assert result.code == mock_bot.code + assert result.requirements == mock_bot.requirements + assert result.index_urls == mock_bot.index_urls + assert result.config_env == mock_bot.config_env + assert result.required_credentials == mock_bot.required_credentials + assert result.credentials == mock_bot.credentials + assert result.version == mock_bot.version + + mock_session.execute.assert_awaited_once() + +@pytest.mark.asyncio +async def test_get_session_by_turn_id(): + + turn_id = "test_turn_id" + + mock_session_object = JBSession( + id = "test_session_id", + user_id = "test_user_id", + channel_id = "test_channel_id" + ) + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + + mock_execute_result = mock.Mock() + mock_execute_result.scalars.return_value.first.return_value = mock_session_object + + mock_session.execute = mock.AsyncMock(return_value=mock_execute_result) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + + result = await get_session_by_turn_id(turn_id) + + assert isinstance(result,JBSession) + assert result.id == mock_session_object.id + assert result.user_id == mock_session_object.user_id + assert result.channel_id == mock_session_object.channel_id + + mock_session.execute.assert_awaited_once() + +@pytest.mark.asyncio +async def test_get_all_bots(): + + mock_bot1 = JBBot(id="test_bot_1", name="Bot1", status="active") + mock_bot2 = JBBot(id="test_bot_2", name="Bot2", status="inactive") + + bot_list = [mock_bot1,mock_bot2] + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + + mock_execute_result = mock.Mock() + mock_execute_result.scalars.return_value.all.return_value = bot_list + + mock_session.execute = mock.AsyncMock(return_value=mock_execute_result) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + result = await get_all_bots() + + assert isinstance(result,list) + assert len(result) == len(bot_list) + for item in result: + assert isinstance(item, JBBot) + assert item in bot_list + mock_session.execute.assert_awaited_once() + +def test_insert_jb_webhook_reference(): + + reference_id = "test_reference_id" + turn_id = "test_turn_id" + + with mock.patch.object(DBSessionHandler, 'get_sync_session') as mock_get_session: + + mock_session = mock.MagicMock() + mock_get_session.return_value = mock_session + + mock_session.begin.return_value = mock.MagicMock() + mock_session.commit.return_value = None + + result = insert_jb_webhook_reference(reference_id, turn_id) + + assert result is not None + assert result == reference_id + +@pytest.mark.asyncio +async def test_create_bot(): + + bot_id = "test_bot_id" + name = "test_bot" + code = "test_code" + requirements="test_requirements" + index_urls = {"index_url1":"test_url_1","index_url2":"test_url_2"} + required_credentials = {"API_KEY":"test_api_key"} + version = "0.0.1" + + mock_bot = JBBot( + id=bot_id, + name=name, + code=code, + requirements=requirements, + index_urls=index_urls, + required_credentials=required_credentials, + version=version, + ) + + mock_session = mock.Mock() + mock_session.commit = mock.AsyncMock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + result = await create_bot(bot_id, name, code, requirements, index_urls, required_credentials, version) + + assert result is not None + assert isinstance(result, JBBot) + assert result.id == mock_bot.id + assert result.name == mock_bot.name + assert result.code == mock_bot.code + assert result.requirements == mock_bot.requirements + assert result.index_urls == mock_bot.index_urls + assert result.required_credentials == mock_bot.required_credentials + assert result.version == mock_bot.version + + mock_session.commit.assert_awaited_once() + +@pytest.mark.asyncio +async def test_create_message(): + + message_id = "test_message_id" + message = Message( + message_type=MessageType.TEXT, + text=TextMessage(body="test_text_message"), + ) + + turn_id = "test_turn_id", + message_type = "text", + message = json.loads(getattr(message, message.message_type.value).model_dump_json(exclude_none=True)) + is_user_sent = True, + + mock_session = mock.Mock() + mock_session.commit = mock.AsyncMock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + with mock.patch("src.crud.uuid.uuid4", return_value = message_id): + + result = await create_message(turn_id, message_type, message, is_user_sent) + + assert result is not None + assert result == message_id + + mock_session.commit.assert_awaited_once() + +@pytest.mark.asyncio +async def test_update_user_language(): + + turn_id = "test_turn_id" + selected_language = "English" + + mock_session = mock.Mock() + mock_session.begin = mock.Mock(return_value=AsyncBeginMock()) + mock_session.commit = mock.AsyncMock() + + mock_execute = mock.AsyncMock() + mock_session.execute = mock_execute + + with mock.patch.object(DBSessionHandler, 'get_async_session', return_value=AsyncContextManagerMock(mock_session)): + + result = await update_user_language(turn_id, selected_language) + + assert result is None + + mock_session.execute.assert_awaited_once() + mock_session.commit.assert_awaited_once() \ No newline at end of file From 0bd7f9fbccc28fb04dd339e9ddde49d5a41e096d Mon Sep 17 00:00:00 2001 From: Ananya Agrawal Date: Wed, 18 Dec 2024 18:04:08 +0530 Subject: [PATCH 11/11] Modified retriever to improve test coverage --- retriever/qa_engine.py | 180 ----------------------------------------- 1 file changed, 180 deletions(-) delete mode 100644 retriever/qa_engine.py diff --git a/retriever/qa_engine.py b/retriever/qa_engine.py deleted file mode 100644 index ddc58309..00000000 --- a/retriever/qa_engine.py +++ /dev/null @@ -1,180 +0,0 @@ -import time -from enum import Enum -from abc import ABC, abstractmethod -from pydantic import BaseModel -from lib.document_collection import DocumentCollection -from lib.speech_processor import SpeechProcessor -from lib.translator import Translator -from lib.audio_converter import convert_to_wav_with_ffmpeg -from lib.model import Language -from .model import MediaFormat -from .model import IncorrectInputException -from .query_with_langchain import querying_with_langchain -from lib.jb_logging import Logger -from lib.logging_repository import LoggingRepository -import uuid - -logger = Logger("rag_engine") - - -class QueryResponse(BaseModel): - query: str - query_in_english: str = "" - answer: str - answer_in_english: str = "" - audio_output_url: str = "" - - -class LangchainQAModel(Enum): - GPT35_TURBO = "gpt-3.5-turbo-1106" - GPT4 = "gpt-4" - - -class QAEngine(ABC): - @abstractmethod - async def query( - self, - query: str = "", - speech_query_url: str = "", - input_language: Language = Language.EN, - output_format: MediaFormat = MediaFormat.TEXT, - ) -> QueryResponse: - pass - - -class LangchainQAEngine(QAEngine): - def __init__( - self, - document_collection: DocumentCollection, - speech_processor: SpeechProcessor, - translator: Translator, - model: LangchainQAModel, # Remove (Customizable) - logging_repository: LoggingRepository, - ): - self.document_collection = document_collection - self.speech_processor = speech_processor - self.translator = translator - self.model = model - self.logging_repository = logging_repository - - async def query( - self, - user_id: str, # Should be customized (Backend) - app_id: str, # Should be customized (Backend) - query: str = "", - speech_query_url: str = "", - prompt: str = "", # May not be needed - input_language: Language = Language.EN, - output_format: MediaFormat = MediaFormat.TEXT, - ) -> QueryResponse: - qa_id = str(uuid.uuid1()) - logger.info(f"Querying with QA ID: {qa_id}") - is_voice = False - audio_output_url = "" - if query == "" and speech_query_url == "": - error_message = "Query input is missing" - logger.error(error_message) - await self.logging_repository.insert_qa_log( - id=qa_id, - user_id=user_id, - app_id=app_id, - document_uuid="some-predefined-id", - input_language=input_language, - query=query, - audio_input_link=speech_query_url, - response="", - audio_output_link=audio_output_url, - retrieval_k_value=0, - retrieved_chunks=[], - prompt="", - gpt_model_name="", - status_code=422, - status_message=error_message, - response_time=10, - ) - raise IncorrectInputException(error_message) - - if query != "": - logger.info("Query Type: Text") - if output_format.name == "VOICE": - logger.info("Response Type: Voice") - is_voice = True - else: - logger.info("Response Type: Text") - else: - logger.info("Query Type: Voice") - wav_data = await convert_to_wav_with_ffmpeg(speech_query_url) - logger.info("Audio converted to wav") - query = await self.speech_processor.speech_to_text( - self.logging_repository, wav_data, input_language - ) - logger.info("Response Type: Voice") - is_voice = True - - logger.info(f"User Query: {query}") - logger.info( - "Query Input Language: " + input_language.value - ) - query_in_english = await self.translator.translate_text( - self.logging_repository, qa_id, query, input_language, Language.EN - ) - k, chunks, prompt, answer_in_english = await querying_with_langchain( - query_in_english, prompt, self.model.value - ) - logger.info("RAG is successful") - logger.info(f"Query in English: {query_in_english}") - logger.info(f"K value: {str(k)}") - logger.info(f"Chunks: {', '.join(chunks)}") - logger.info(f"Prompt: {prompt}") - logger.info(f"Answer in English: {answer_in_english}") - - answer = await self.translator.translate_text( - self.logging_repository, - qa_id, - answer_in_english, - Language.EN, - input_language, - ) - logger.info(f"Answer: {answer}") - if is_voice: - audio_content = await self.speech_processor.text_to_speech( - self.logging_repository, qa_id, answer, input_language - ) - time_stamp = time.strftime("%Y%m%d-%H%M%S") - filename = "output_audio_files/audio-output-" + time_stamp + ".mp3" - logger.info(f"Writing audio file: {filename}") - await self.document_collection.write_audio_file(filename, audio_content) - audio_output_url = await self.document_collection.audio_file_public_url( - filename - ) - logger.info( - f"Audio output URL: {audio_output_url}" - ) - - return_message = "RAG Engine process completed successfully" - logger.info(return_message) - await self.logging_repository.insert_qa_log( - id=qa_id, - user_id=user_id, - app_id=app_id, - document_uuid="some-predefined-id", - input_language=input_language, - query=query, - audio_input_link=speech_query_url, - response=answer, - audio_output_link=audio_output_url, - retrieval_k_value=k, - retrieved_chunks=chunks, - prompt=prompt, - gpt_model_name=self.model.value, - status_code=200, - status_message=return_message, - response_time=10, - ) - return QueryResponse( - query=query, - query_in_english=query_in_english, - answer=answer, - answer_in_english=answer_in_english, - audio_output_url=audio_output_url, - )