|
1 | | -from unittest.mock import AsyncMock, MagicMock |
2 | | - |
3 | 1 | import pytest |
4 | 2 | from slack_sdk.web.async_client import AsyncWebClient |
5 | 3 |
|
6 | 4 | from slack_bolt.context.say_stream.async_say_stream import AsyncSayStream |
7 | 5 | from slack_bolt.warning import ExperimentalWarning |
| 6 | +from tests.mock_web_api_server import ( |
| 7 | + cleanup_mock_web_api_server, |
| 8 | + setup_mock_web_api_server, |
| 9 | +) |
| 10 | +from tests.utils import remove_os_env_temporarily, restore_os_env |
8 | 11 |
|
9 | 12 |
|
10 | 13 | class TestAsyncSayStream: |
11 | | - def setup_method(self): |
12 | | - self.mock_client = MagicMock(spec=AsyncWebClient) |
13 | | - self.mock_client.chat_stream = AsyncMock() |
| 14 | + @pytest.fixture(scope="function", autouse=True) |
| 15 | + def setup_teardown(self): |
| 16 | + old_os_env = remove_os_env_temporarily() |
| 17 | + setup_mock_web_api_server(self) |
| 18 | + valid_token = "xoxb-valid" |
| 19 | + mock_api_server_base_url = "http://localhost:8888" |
| 20 | + try: |
| 21 | + self.web_client = AsyncWebClient(token=valid_token, base_url=mock_api_server_base_url) |
| 22 | + yield # run the test here |
| 23 | + finally: |
| 24 | + cleanup_mock_web_api_server(self) |
| 25 | + restore_os_env(old_os_env) |
14 | 26 |
|
15 | 27 | @pytest.mark.asyncio |
16 | 28 | async def test_missing_channel_raises(self): |
17 | | - say_stream = AsyncSayStream(client=self.mock_client, channel_id=None, thread_ts="111.222") |
| 29 | + say_stream = AsyncSayStream(client=self.web_client, channel_id=None, thread_ts="111.222") |
18 | 30 | with pytest.warns(ExperimentalWarning): |
19 | 31 | with pytest.raises(ValueError, match="channel"): |
20 | 32 | await say_stream() |
21 | 33 |
|
22 | 34 | @pytest.mark.asyncio |
23 | 35 | async def test_missing_thread_ts_raises(self): |
24 | | - say_stream = AsyncSayStream(client=self.mock_client, channel_id="C111", thread_ts=None) |
| 36 | + say_stream = AsyncSayStream(client=self.web_client, channel_id="C111", thread_ts=None) |
25 | 37 | with pytest.warns(ExperimentalWarning): |
26 | 38 | with pytest.raises(ValueError, match="thread_ts"): |
27 | 39 | await say_stream() |
28 | 40 |
|
29 | 41 | @pytest.mark.asyncio |
30 | 42 | async def test_default_params(self): |
31 | 43 | say_stream = AsyncSayStream( |
32 | | - client=self.mock_client, |
| 44 | + client=self.web_client, |
33 | 45 | channel_id="C111", |
34 | 46 | thread_ts="111.222", |
35 | 47 | team_id="T111", |
36 | 48 | user_id="U111", |
37 | 49 | ) |
38 | | - await say_stream() |
39 | | - |
40 | | - self.mock_client.chat_stream.assert_called_once_with( |
41 | | - channel="C111", |
42 | | - thread_ts="111.222", |
43 | | - recipient_team_id="T111", |
44 | | - recipient_user_id="U111", |
45 | | - ) |
| 50 | + stream = await say_stream() |
| 51 | + assert stream._stream_args == { |
| 52 | + "channel": "C111", |
| 53 | + "thread_ts": "111.222", |
| 54 | + "recipient_team_id": "T111", |
| 55 | + "recipient_user_id": "U111", |
| 56 | + "task_display_mode": None, |
| 57 | + } |
46 | 58 |
|
47 | 59 | @pytest.mark.asyncio |
48 | 60 | async def test_parameter_overrides(self): |
49 | 61 | say_stream = AsyncSayStream( |
50 | | - client=self.mock_client, |
| 62 | + client=self.web_client, |
51 | 63 | channel_id="C111", |
52 | 64 | thread_ts="111.222", |
53 | 65 | team_id="T111", |
54 | 66 | user_id="U111", |
55 | 67 | ) |
56 | | - await say_stream(channel="C222", thread_ts="333.444", recipient_team_id="T222", recipient_user_id="U222") |
| 68 | + stream = await say_stream(channel="C222", thread_ts="333.444", recipient_team_id="T222", recipient_user_id="U222") |
57 | 69 |
|
58 | | - self.mock_client.chat_stream.assert_called_once_with( |
59 | | - channel="C222", |
60 | | - thread_ts="333.444", |
61 | | - recipient_team_id="T222", |
62 | | - recipient_user_id="U222", |
63 | | - ) |
| 70 | + assert stream._stream_args == { |
| 71 | + "channel": "C222", |
| 72 | + "thread_ts": "333.444", |
| 73 | + "recipient_team_id": "T222", |
| 74 | + "recipient_user_id": "U222", |
| 75 | + "task_display_mode": None, |
| 76 | + } |
64 | 77 |
|
65 | 78 | @pytest.mark.asyncio |
66 | 79 | async def test_buffer_size_passthrough(self): |
67 | 80 | say_stream = AsyncSayStream( |
68 | | - client=self.mock_client, |
| 81 | + client=self.web_client, |
69 | 82 | channel_id="C111", |
70 | 83 | thread_ts="111.222", |
71 | 84 | ) |
72 | | - await say_stream(buffer_size=100) |
| 85 | + stream = await say_stream(buffer_size=100) |
73 | 86 |
|
74 | | - self.mock_client.chat_stream.assert_called_once_with( |
75 | | - buffer_size=100, |
76 | | - channel="C111", |
77 | | - thread_ts="111.222", |
78 | | - recipient_team_id=None, |
79 | | - recipient_user_id=None, |
80 | | - ) |
| 87 | + assert stream._buffer_size == 100 |
81 | 88 |
|
82 | 89 | @pytest.mark.asyncio |
83 | 90 | async def test_experimental_warning(self): |
84 | 91 | say_stream = AsyncSayStream( |
85 | | - client=self.mock_client, |
| 92 | + client=self.web_client, |
86 | 93 | channel_id="C111", |
87 | 94 | thread_ts="111.222", |
88 | 95 | ) |
|
0 commit comments