Przeglądaj źródła

ci(pytest): use one class to filter the nightly_run

Fu Hanxi 2 lat temu
rodzic
commit
17bbb2a5a2

+ 1 - 0
tools/ci/check_build_test_rules.py

@@ -524,6 +524,7 @@ if __name__ == '__main__':
                 extra_default_build_targets=extra_default_build_targets_list,
             )
         elif arg.action == 'check-test-scripts':
+            os.environ['INCLUDE_NIGHTLY_RUN'] = '1'
             check_test_scripts(
                 list(check_dirs),
                 exclude_dirs=_exclude_dirs,

+ 1 - 9
tools/ci/ci_build_apps.py

@@ -48,15 +48,7 @@ def get_pytest_apps(
     for case in pytest_cases:
         for app in case.apps:
             _paths.add(app.path)
-
-            if os.getenv('INCLUDE_NIGHTLY_RUN') == '1':
-                test_related_app_configs[app.path].add(app.config)
-            elif os.getenv('NIGHTLY_RUN') == '1':
-                if case.nightly_run:
-                    test_related_app_configs[app.path].add(app.config)
-            else:
-                if not case.nightly_run:
-                    test_related_app_configs[app.path].add(app.config)
+            test_related_app_configs[app.path].add(app.config)
 
     if not extra_default_build_targets:
         extra_default_build_targets = []

+ 53 - 3
tools/ci/idf_pytest/constants.py

@@ -4,10 +4,13 @@
 """
 Pytest Related Constants. Don't import third-party packages here.
 """
-
+import os
 import typing as t
 from dataclasses import dataclass
 
+from _pytest.python import Function
+from pytest_embedded.utils import to_list
+
 SUPPORTED_TARGETS = ['esp32', 'esp32s2', 'esp32c3', 'esp32s3', 'esp32c2', 'esp32c6', 'esp32h2']
 PREVIEW_TARGETS: t.List[str] = []  # this PREVIEW_TARGETS excludes 'linux' target
 DEFAULT_SDKCONFIG = 'default'
@@ -113,9 +116,56 @@ class PytestApp:
 class PytestCase:
     path: str
     name: str
+
     apps: t.Set[PytestApp]
+    target: str
 
-    nightly_run: bool
+    item: Function
 
     def __hash__(self) -> int:
-        return hash((self.path, self.name, self.apps, self.nightly_run))
+        return hash((self.path, self.name, self.apps, self.all_markers))
+
+    @property
+    def all_markers(self) -> t.Set[str]:
+        return {marker.name for marker in self.item.iter_markers()}
+
+    @property
+    def is_nightly_run(self) -> bool:
+        return 'nightly_run' in self.all_markers
+
+    @property
+    def target_markers(self) -> t.Set[str]:
+        return {marker for marker in self.all_markers if marker in TARGET_MARKERS}
+
+    @property
+    def env_markers(self) -> t.Set[str]:
+        return {marker for marker in self.all_markers if marker in ENV_MARKERS}
+
+    @property
+    def skipped_targets(self) -> t.Set[str]:
+        def _get_temp_markers_disabled_targets(marker_name: str) -> t.Set[str]:
+            temp_marker = self.item.get_closest_marker(marker_name)
+
+            if not temp_marker:
+                return set()
+
+            # temp markers should always use keyword arguments `targets` and `reason`
+            if not temp_marker.kwargs.get('targets') or not temp_marker.kwargs.get('reason'):
+                raise ValueError(
+                    f'`{marker_name}` should always use keyword arguments `targets` and `reason`. '
+                    f'For example: '
+                    f'`@pytest.mark.{marker_name}(targets=["esp32"], reason="IDF-xxxx, will fix it ASAP")`'
+                )
+
+            return set(to_list(temp_marker.kwargs['targets']))  # type: ignore
+
+        temp_skip_ci_targets = _get_temp_markers_disabled_targets('temp_skip_ci')
+        temp_skip_targets = _get_temp_markers_disabled_targets('temp_skip')
+
+        # in CI we skip the union of `temp_skip` and `temp_skip_ci`
+        if os.getenv('CI_JOB_ID'):
+            skip_targets = temp_skip_ci_targets.union(temp_skip_targets)
+        else:  # we use `temp_skip` locally
+            skip_targets = temp_skip_targets
+
+        return skip_targets

+ 63 - 72
tools/ci/idf_pytest/plugin.py

@@ -17,7 +17,7 @@ from pytest_embedded.plugin import parse_multi_dut_args
 from pytest_embedded.utils import find_by_suffix, to_list
 
 from .constants import DEFAULT_SDKCONFIG, PREVIEW_TARGETS, SUPPORTED_TARGETS, PytestApp, PytestCase
-from .utils import format_case_id, item_marker_names, item_skip_targets, merge_junit_files
+from .utils import format_case_id, merge_junit_files
 
 IDF_PYTEST_EMBEDDED_KEY = pytest.StashKey['IdfPytestEmbedded']()
 ITEM_FAILED_CASES_KEY = pytest.StashKey[list]()
@@ -38,6 +38,8 @@ class IdfPytestEmbedded:
         self.known_failure_patterns = self._parse_known_failure_cases_file(known_failure_cases_file)
         self.apps_list = apps_list
 
+        self.cases: t.List[PytestCase] = []
+
         self._failed_cases: t.List[t.Tuple[str, bool, bool]] = []  # (test_case_name, is_known_failure_cases, is_xfail)
 
     @property
@@ -72,6 +74,49 @@ class IdfPytestEmbedded:
 
         return patterns
 
+    @staticmethod
+    def get_param(item: Function, key: str, default: t.Any = None) -> t.Any:
+        # implement like this since this is a limitation of pytest, couldn't get fixture values while collecting
+        # https://github.com/pytest-dev/pytest/discussions/9689
+        if not hasattr(item, 'callspec'):
+            raise ValueError(f'Function {item} does not have params')
+
+        return item.callspec.params.get(key, default) or default
+
+    def item_to_pytest_case(self, item: Function) -> PytestCase:
+        count = 1
+        case_path = str(item.path)
+        case_name = item.originalname
+        target = self.target
+
+        # funcargs is not calculated while collection
+        if hasattr(item, 'callspec'):
+            count = item.callspec.params.get('count', 1)
+            app_paths = to_list(
+                parse_multi_dut_args(
+                    count,
+                    self.get_param(item, 'app_path', os.path.dirname(case_path)),
+                )
+            )
+            configs = to_list(parse_multi_dut_args(count, self.get_param(item, 'config', 'default')))
+            targets = to_list(parse_multi_dut_args(count, self.get_param(item, 'target', target)))
+        else:
+            app_paths = [os.path.dirname(case_path)]
+            configs = ['default']
+            targets = [target]
+
+        case_apps = set()
+        for i in range(count):
+            case_apps.add(PytestApp(app_paths[i], targets[i], configs[i]))
+
+        return PytestCase(
+            case_path,
+            case_name,
+            case_apps,
+            self.target,
+            item,
+        )
+
     @pytest.hookimpl(tryfirst=True)
     def pytest_sessionstart(self, session: Session) -> None:
         # same behavior for vanilla pytest-embedded '--target'
@@ -79,24 +124,17 @@ class IdfPytestEmbedded:
 
     @pytest.hookimpl(tryfirst=True)
     def pytest_collection_modifyitems(self, items: t.List[Function]) -> None:
-        # sort by file path and callspec.config
-        # implement like this since this is a limitation of pytest, couldn't get fixture values while collecting
-        # https://github.com/pytest-dev/pytest/discussions/9689
-        # after sort the test apps, the test may use the app cache to reduce the flash times.
-        def _get_param_config(_item: Function) -> str:
-            if hasattr(_item, 'callspec'):
-                return _item.callspec.params.get('config', DEFAULT_SDKCONFIG)  # type: ignore
-            return DEFAULT_SDKCONFIG  # type: ignore
-
-        items.sort(key=lambda x: (os.path.dirname(x.path), _get_param_config(x)))
-
-        # set default timeout 10 minutes for each case
+        item_to_case: t.Dict[Function, PytestCase] = {}
         for item in items:
+            # generate PytestCase for each item
+            case = self.item_to_pytest_case(item)
+            item_to_case[item] = case
+
+            # set default timeout 10 minutes for each case
             if 'timeout' not in item.keywords:
                 item.add_marker(pytest.mark.timeout(10 * 60))
 
-        # add markers for special markers
-        for item in items:
+            # add markers for special markers
             if 'supported_targets' in item.keywords:
                 for _target in SUPPORTED_TARGETS:
                     item.add_marker(_target)
@@ -109,11 +147,7 @@ class IdfPytestEmbedded:
 
             # add 'xtal_40mhz' tag as a default tag for esp32c2 target
             # only add this marker for esp32c2 cases
-            if (
-                self.target == 'esp32c2'
-                and 'esp32c2' in item_marker_names(item)
-                and 'xtal_26mhz' not in item_marker_names(item)
-            ):
+            if self.target == 'esp32c2' and 'esp32c2' in case.target_markers and 'xtal_26mhz' not in case.all_markers:
                 item.add_marker('xtal_40mhz')
 
         # filter all the test cases with "nightly_run" marker
@@ -121,20 +155,25 @@ class IdfPytestEmbedded:
             # Do not filter nightly_run cases
             pass
         elif os.getenv('NIGHTLY_RUN') == '1':
-            items[:] = [item for item in items if 'nightly_run' in item_marker_names(item)]
+            items[:] = [item for item in items if item_to_case[item].is_nightly_run]
         else:
-            items[:] = [item for item in items if 'nightly_run' not in item_marker_names(item)]
+            items[:] = [item for item in items if not item_to_case[item].is_nightly_run]
 
         # filter all the test cases with target and skip_targets
         items[:] = [
             item
             for item in items
-            if self.target in item_marker_names(item) and self.target not in item_skip_targets(item)
+            if self.target in item_to_case[item].target_markers
+            and self.target not in item_to_case[item].skipped_targets
         ]
 
         # filter all the test cases with cli option "config"
         if self.sdkconfig:
-            items[:] = [item for item in items if _get_param_config(item) == self.sdkconfig]
+            items[:] = [item for item in items if self.get_param(item, 'config', DEFAULT_SDKCONFIG) == self.sdkconfig]
+
+    def pytest_report_collectionfinish(self, items: t.List[Function]) -> None:
+        for item in items:
+            self.cases.append(self.item_to_pytest_case(item))
 
     def pytest_runtest_makereport(self, item: Function, call: CallInfo[None]) -> t.Optional[TestReport]:
         report = TestReport.from_item_and_call(item, call)
@@ -236,51 +275,3 @@ class IdfPytestEmbedded:
         if self.failed_cases:
             terminalreporter.section('Failed cases', bold=True, red=True)
             terminalreporter.line('\n'.join(self.failed_cases))
-
-
-class PytestCollectPlugin:
-    def __init__(self, target: str) -> None:
-        self.target = target
-        self.cases: t.List[PytestCase] = []
-
-    @staticmethod
-    def get_param(item: 'Function', key: str, default: t.Any = None) -> t.Any:
-        if not hasattr(item, 'callspec'):
-            raise ValueError(f'Function {item} does not have params')
-
-        return item.callspec.params.get(key, default) or default
-
-    def pytest_report_collectionfinish(self, items: t.List['Function']) -> None:
-        for item in items:
-            count = 1
-            case_path = str(item.path)
-            case_name = item.originalname
-            target = self.target
-            # funcargs is not calculated while collection
-            if hasattr(item, 'callspec'):
-                count = item.callspec.params.get('count', 1)
-                app_paths = to_list(
-                    parse_multi_dut_args(
-                        count,
-                        self.get_param(item, 'app_path', os.path.dirname(case_path)),
-                    )
-                )
-                configs = to_list(parse_multi_dut_args(count, self.get_param(item, 'config', 'default')))
-                targets = to_list(parse_multi_dut_args(count, self.get_param(item, 'target', target)))
-            else:
-                app_paths = [os.path.dirname(case_path)]
-                configs = ['default']
-                targets = [target]
-
-            case_apps = set()
-            for i in range(count):
-                case_apps.add(PytestApp(app_paths[i], targets[i], configs[i]))
-
-            self.cases.append(
-                PytestCase(
-                    case_path,
-                    case_name,
-                    case_apps,
-                    'nightly_run' in [marker.name for marker in item.iter_markers()],
-                )
-            )

+ 9 - 31
tools/ci/idf_pytest/script.py

@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import io
-import os
 import typing as t
 from contextlib import redirect_stdout
 from pathlib import Path
@@ -14,7 +13,7 @@ from idf_py_actions.constants import SUPPORTED_TARGETS as TOOLS_SUPPORTED_TARGET
 from pytest_embedded.utils import to_list
 
 from .constants import PytestCase
-from .plugin import PytestCollectPlugin
+from .plugin import IdfPytestEmbedded
 
 
 def get_pytest_files(paths: t.List[str]) -> t.List[str]:
@@ -47,19 +46,6 @@ def get_pytest_cases(
 
     paths = to_list(paths)
 
-    origin_include_nightly_run_env = os.getenv('INCLUDE_NIGHTLY_RUN')
-    origin_nightly_run_env = os.getenv('NIGHTLY_RUN')
-
-    # disable the env vars to get all test cases
-    if 'INCLUDE_NIGHTLY_RUN' in os.environ:
-        os.environ.pop('INCLUDE_NIGHTLY_RUN')
-
-    if 'NIGHTLY_RUN' in os.environ:
-        os.environ.pop('NIGHTLY_RUN')
-
-    # collect all cases
-    os.environ['INCLUDE_NIGHTLY_RUN'] = '1'
-
     cases: t.List[PytestCase] = []
     pytest_scripts = get_pytest_files(paths)  # type: ignore
     if not pytest_scripts:
@@ -67,7 +53,7 @@ def get_pytest_cases(
         return cases
 
     for target in targets:
-        collector = PytestCollectPlugin(target)
+        collector = IdfPytestEmbedded(target)
 
         with io.StringIO() as buf:
             with redirect_stdout(buf):
@@ -77,22 +63,14 @@ def get_pytest_cases(
                 if filter_expr:
                     cmd.extend(['-k', filter_expr])
                 res = pytest.main(cmd, plugins=[collector])
-            if res.value != ExitCode.OK:
-                if res.value == ExitCode.NO_TESTS_COLLECTED:
-                    print(f'WARNING: no pytest app found for target {target} under paths {", ".join(paths)}')
-                else:
-                    print(buf.getvalue())
-                    raise RuntimeError(
-                        f'pytest collection failed at {", ".join(paths)} with command \"{" ".join(cmd)}\"'
-                    )
 
-        cases.extend(collector.cases)
-
-    # revert back the env vars
-    if origin_include_nightly_run_env is not None:
-        os.environ['INCLUDE_NIGHTLY_RUN'] = origin_include_nightly_run_env
+        if res.value != ExitCode.OK:
+            if res.value == ExitCode.NO_TESTS_COLLECTED:
+                print(f'WARNING: no pytest app found for target {target} under paths {", ".join(paths)}')
+            else:
+                print(buf.getvalue())
+                raise RuntimeError(f'pytest collection failed at {", ".join(paths)} with command \"{" ".join(cmd)}\"')
 
-    if origin_nightly_run_env is not None:
-        os.environ['NIGHTLY_RUN'] = origin_nightly_run_env
+        cases.extend(collector.cases)
 
     return cases

+ 1 - 55
tools/ci/idf_pytest/utils.py

@@ -6,10 +6,7 @@ import os
 import typing as t
 from xml.etree import ElementTree as ET
 
-from _pytest.nodes import Item
-from pytest_embedded.utils import to_list
-
-from .constants import ENV_MARKERS, TARGET_MARKERS
+from .constants import TARGET_MARKERS
 
 
 def format_case_id(target: t.Optional[str], config: t.Optional[str], case: str, is_qemu: bool = False) -> str:
@@ -23,57 +20,6 @@ def format_case_id(target: t.Optional[str], config: t.Optional[str], case: str,
     return '.'.join(parts)
 
 
-def item_marker_names(item: Item) -> t.List[str]:
-    return [marker.name for marker in item.iter_markers()]
-
-
-def item_target_marker_names(item: Item) -> t.List[str]:
-    res = set()
-    for marker in item.iter_markers():
-        if marker.name in TARGET_MARKERS:
-            res.add(marker.name)
-
-    return sorted(res)
-
-
-def item_env_marker_names(item: Item) -> t.List[str]:
-    res = set()
-    for marker in item.iter_markers():
-        if marker.name in ENV_MARKERS:
-            res.add(marker.name)
-
-    return sorted(res)
-
-
-def item_skip_targets(item: Item) -> t.List[str]:
-    def _get_temp_markers_disabled_targets(marker_name: str) -> t.List[str]:
-        temp_marker = item.get_closest_marker(marker_name)
-
-        if not temp_marker:
-            return []
-
-        # temp markers should always use keyword arguments `targets` and `reason`
-        if not temp_marker.kwargs.get('targets') or not temp_marker.kwargs.get('reason'):
-            raise ValueError(
-                f'`{marker_name}` should always use keyword arguments `targets` and `reason`. '
-                f'For example: '
-                f'`@pytest.mark.{marker_name}(targets=["esp32"], reason="IDF-xxxx, will fix it ASAP")`'
-            )
-
-        return to_list(temp_marker.kwargs['targets'])  # type: ignore
-
-    temp_skip_ci_targets = _get_temp_markers_disabled_targets('temp_skip_ci')
-    temp_skip_targets = _get_temp_markers_disabled_targets('temp_skip')
-
-    # in CI we skip the union of `temp_skip` and `temp_skip_ci`
-    if os.getenv('CI_JOB_ID'):
-        skip_targets = list(set(temp_skip_ci_targets).union(set(temp_skip_targets)))
-    else:  # we use `temp_skip` locally
-        skip_targets = temp_skip_targets
-
-    return skip_targets
-
-
 def get_target_marker_from_expr(markexpr: str) -> str:
     candidates = set()
     # we use `-m "esp32 and generic"` in our CI to filter the test cases