diff --git a/tests/integrations/langchain/test_langchain.py b/tests/integrations/langchain/test_langchain.py
index 240a78e2cc..336be2fb1e 100644
--- a/tests/integrations/langchain/test_langchain.py
+++ b/tests/integrations/langchain/test_langchain.py
@@ -237,26 +237,6 @@ def get_word_length(word: str) -> int:
     return len(word)
 
 
-global stream_result_mock  # type: Mock
-global llm_type  # type: str
-
-
-class MockOpenAI(ChatOpenAI):
-    def _stream(
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> Iterator[ChatGenerationChunk]:
-        for x in stream_result_mock():
-            yield x
-
-    @property
-    def _llm_type(self) -> str:
-        return llm_type
-
-
 def test_langchain_text_completion(
     sentry_init,
     capture_events,
@@ -1488,8 +1468,22 @@ def test_langchain_openai_tools_agent_stream_with_config(
 
 
 def test_langchain_error(sentry_init, capture_events):
-    global llm_type
-    llm_type = "acme-llm"
+    class MockOpenAI(ChatOpenAI):
+        def _stream(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> Iterator[ChatGenerationChunk]:
+            stream_result_mock = Mock(side_effect=ValueError("API rate limit error"))
+
+            for x in stream_result_mock():
+                yield x
+
+        @property
+        def _llm_type(self) -> str:
+            return "acme-llm"
 
     sentry_init(
         integrations=[LangchainIntegration(include_prompts=True)],
@@ -1508,8 +1502,6 @@ def test_langchain_error(sentry_init, capture_events):
             MessagesPlaceholder(variable_name="agent_scratchpad"),
         ]
     )
-    global stream_result_mock
-    stream_result_mock = Mock(side_effect=ValueError("API rate limit error"))
     llm = MockOpenAI(
         model_name="gpt-3.5-turbo",
         temperature=0,
@@ -1527,8 +1519,22 @@
 
 
 def test_span_status_error(sentry_init, capture_events):
-    global llm_type
-    llm_type = "acme-llm"
+    class MockOpenAI(ChatOpenAI):
+        def _stream(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> Iterator[ChatGenerationChunk]:
+            stream_result_mock = Mock(side_effect=ValueError("API rate limit error"))
+
+            for x in stream_result_mock():
+                yield x
+
+        @property
+        def _llm_type(self) -> str:
+            return "acme-llm"
 
     sentry_init(
         integrations=[LangchainIntegration(include_prompts=True)],
@@ -1547,8 +1553,6 @@ def test_span_status_error(sentry_init, capture_events):
             MessagesPlaceholder(variable_name="agent_scratchpad"),
         ]
     )
-    global stream_result_mock
-    stream_result_mock = Mock(side_effect=ValueError("API rate limit error"))
     llm = MockOpenAI(
         model_name="gpt-3.5-turbo",
         temperature=0,
@@ -1781,8 +1785,32 @@ def test_langchain_callback_list_existing_callback(sentry_init):
 
 def test_langchain_message_role_mapping(sentry_init, capture_events):
     """Test that message roles are properly normalized in langchain integration."""
-    global llm_type
-    llm_type = "openai-chat"
+
+    class MockOpenAI(ChatOpenAI):
+        def _stream(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> Iterator[ChatGenerationChunk]:
+            stream_result_mock = Mock(
+                side_effect=[
+                    [
+                        ChatGenerationChunk(
+                            type="ChatGenerationChunk",
+                            message=AIMessageChunk(content="Test response"),
+                        ),
+                    ]
+                ]
+            )
+
+            for x in stream_result_mock():
+                yield x
+
+        @property
+        def _llm_type(self) -> str:
+            return "openai-chat"
 
     sentry_init(
         integrations=[LangchainIntegration(include_prompts=True)],
@@ -1799,18 +1827,6 @@ def test_langchain_message_role_mapping(sentry_init, capture_events):
         ]
     )
 
-    global stream_result_mock
-    stream_result_mock = Mock(
-        side_effect=[
-            [
-                ChatGenerationChunk(
-                    type="ChatGenerationChunk",
-                    message=AIMessageChunk(content="Test response"),
-                ),
-            ]
-        ]
-    )
-
     llm = MockOpenAI(
         model_name="gpt-3.5-turbo",
         temperature=0,