From 0edb6ca3d733d8cb3fd1ddc7994bde61991ac4c2 Mon Sep 17 00:00:00 2001
From: Frank Hoffmann <44680962+15r10nk@users.noreply.github.com>
Date: Sun, 12 Jan 2025 17:03:09 +0100
Subject: [PATCH] fix: check for pytest compatibility (#94)

---
 executing/__init__.py      |  5 ++++-
 executing/_pytest_utils.py | 16 ++++++++++++++++
 tests/conftest.py          | 25 +++++++++++++++++++++++++
 tests/test_main.py         |  5 -----
 tests/test_pytest.py       |  5 +++++
 5 files changed, 50 insertions(+), 6 deletions(-)
 create mode 100644 executing/_pytest_utils.py
 create mode 100644 tests/conftest.py

diff --git a/executing/__init__.py b/executing/__init__.py
index b645197..e5181a5 100644
--- a/executing/__init__.py
+++ b/executing/__init__.py
@@ -10,6 +10,9 @@
 from collections import namedtuple
 _VersionInfo = namedtuple('_VersionInfo', ('major', 'minor', 'micro'))
 from .executing import Source, Executing, only, NotOneValueFound, cache, future_flags
+
+from ._pytest_utils import is_pytest_compatible
+
 try:
     from .version import __version__ # type: ignore[import]
     if "dev" in __version__:
@@ -22,4 +25,4 @@
     __version_info__ = _VersionInfo(*map(int, __version__.split('.')))
 
 
-__all__ = ["Source"]
+__all__ = ["Source","is_pytest_compatible"]
diff --git a/executing/_pytest_utils.py b/executing/_pytest_utils.py
new file mode 100644
index 0000000..fab8693
--- /dev/null
+++ b/executing/_pytest_utils.py
@@ -0,0 +1,16 @@
+import sys
+
+
+
+def is_pytest_compatible() -> bool:
+    """ returns true if executing can be used for expressions inside assert statements which are rewritten by pytest
+    """
+    if sys.version_info < (3, 11):
+        return False
+
+    try:
+        import pytest
+    except ImportError:
+        return False
+
+    return pytest.version_tuple >= (8, 3, 4)
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..5108348
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,25 @@
+
+
+from typing import Optional, Sequence, Union
+from executing._pytest_utils import is_pytest_compatible
+import _pytest.assertion.rewrite as rewrite
+import importlib.machinery
+import types
+
+if not is_pytest_compatible():
+    original_find_spec = rewrite.AssertionRewritingHook.find_spec
+
+
+    def find_spec(
+        self,
+        name: str,
+        path: Optional[Sequence[Union[str, bytes]]] = None,
+        target: Optional[types.ModuleType] = None,
+    ) -> Optional[importlib.machinery.ModuleSpec]:
+
+        if name == "tests.test_main":
+            return None
+        return original_find_spec(self, name, path, target)
+
+
+    rewrite.AssertionRewritingHook.find_spec = find_spec
diff --git a/tests/test_main.py b/tests/test_main.py
index a3f92ee..e3bc9d6 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -1,9 +1,4 @@
 # -*- coding: utf-8 -*-
-"""
-
-assert rewriting will break executing
-PYTEST_DONT_REWRITE
-"""
 from __future__ import print_function, division
 import ast
 import contextlib
diff --git a/tests/test_pytest.py b/tests/test_pytest.py
index 281598d..5cbe0a2 100644
--- a/tests/test_pytest.py
+++ b/tests/test_pytest.py
@@ -6,6 +6,7 @@
 from time import sleep
 
 import asttokens
+from executing._pytest_utils import is_pytest_compatible
 import pytest
 from littleutils import SimpleNamespace
 
@@ -124,6 +125,10 @@ def check_manual_linecache(filename):
 def test_exception_catching():
     frame = inspect.currentframe()
 
+    if is_pytest_compatible():
+        assert isinstance(Source.executing(frame).node,ast.Call)
+        return 
+
     executing.executing.TESTING = True  # this is already the case in all other tests
     # Sanity check that this operation usually raises an exception.
     # This actually depends on executing not working in the presence of pytest.