diff --git a/README.md b/README.md index 004ae7d..bd4fffe 100644 --- a/README.md +++ b/README.md @@ -413,7 +413,7 @@ When defining an **ASTF profile** you likely want to define the `get_astf_profil This can either be a standalone function which creates the profile from scratch or it can use a "native" TRex profile file. The latter is preferred as it leads to simpler tuning and debugging. For examples see `http_trex_profile`. -When defining an **STF profile** you need to define `get_stf_profile` which should return a path to an already existing +When defining an **STF profile** you might want to define `get_stf_profile` which should return a path to an already existing [traffic profile](https://trex-tgn.cisco.com/trex/doc/trex_manual.html#_traffic_yaml_f_argument_of_stateful). You might also want to change some things in the [platform config](https://trex-tgn.cisco.com/trex/doc/trex_manual.html#_platform_yaml_cfg_argument) which can be done by defining an `stf_config_hook`. This function gets a `ConfigBuilder` instance with the config that would be sent to diff --git a/assets/trex/traffic_profiles/trex_client_manager.py b/assets/trex/traffic_profiles/trex_client_manager.py index 8d31a43..78a966f 100644 --- a/assets/trex/traffic_profiles/trex_client_manager.py +++ b/assets/trex/traffic_profiles/trex_client_manager.py @@ -1,14 +1,10 @@ import copy import os -import subprocess import warnings -from enum import Enum from pathlib import Path from time import sleep, time -from typing import Dict, Literal, Self, Sequence, Tuple +from typing import Dict, Literal, Self -from conftest import get_trex_executor, get_trex_internal, send_pcap_to_trex -from lbr_testsuite.executable import executable from lbr_testsuite.trex import ( TRexAdvancedStateful, TRexManager, @@ -19,20 +15,15 @@ from lbr_trex_client.stf.trex_stf_lib.trex_client import CTRexClient from pytest import FixtureRequest -# ASTFProfile needs to be imported exactly like this -# otherwise it fails an isintance() check +# these need to be imported exactly like this +# otherwise they fail type introspection from trex.astf import trex_astf_profile +from trex.common.trex_exceptions import TRexError + from util.add_vlan import edit_vlan from util.config_builder import ConfigBuilder from util.suri_util import RunInfo - -PcapList = Sequence[Tuple[str, int | float]] - - -class TrexMode(Enum): - STL = (0,) - ASTF = (1,) - STF = 2 +from util.trex_util import PcapList, TrexMode, mkdir_remote, send_to_remote class BaseTrexClientManager: @@ -42,16 +33,18 @@ class BaseTrexClientManager: Subclasses are created as `MyProfile(BaseTrexClientManager, pcaps)`. `pcaps: PcapList` is a list of (str, int) tuples, where int is: + - cps in STF - cps in ASTF - the divisor for `self.BASE_IPG_USEC` in STL - - not used in STF by default (depends on profile implementation) """ pcaps: PcapList + multiplier: float | None = None + duration: int | None = None + _stf_config_path: Path | None = None BASE_IPG_USEC = 12.0 # ~1 Gbps at 1500 bytes per packet PCAP_PATH_PREFIX = Path(__file__).parent / "pcaps" - REMOTE_PCAP_PATH_PREFIX = Path("/tmp/pcaps") def __new__(cls, *args, **kwargs) -> Self: if cls is BaseTrexClientManager: @@ -74,18 +67,29 @@ def __init__( self.pcaps = copy.deepcopy(self.profile_pcaps) self.mode = mode self.vlan_id = target_vlan - self.multiplier: float | None = None - self.duration: int | None = None if len(self.pcaps) < 1: raise ValueError("self.pcaps must contain at least one pcap") + trex_gen = request.config.getoption("--trex-generator") + trex_host = trex_gen[0].split(",") + trex_hostname = trex_host[0] + trex_pcie = trex_host[1] + match self.mode: case TrexMode.STL: # STL mode can only send one pcap at a time so it either # needs to merge them together or replay them one by one # currently it replays them one by one + + # if STL mode gets used more in the future this should create a merged + # pcap that is at least 1M packets long, since there is a lot of overhead + # with small files + self.stl_generator: TRexStateless = manager.request_stateless(request) + self.trex_version = ( + self.stl_generator.get_handler().get_server_version()["version"] + ) self.stl_generator.set_dst_mac(target_mac) if target_vlan != 0: @@ -97,7 +101,8 @@ def __init__( if target_vlan != 0: pcap_path = Path(edit_vlan(str(pcap_path), target_vlan)) self.pcaps[i] = (pcap_path.name, pcap[1]) - send_pcap_to_trex(str(pcap_path), request) + pcap_remote_path = self.get_remote_data_path(pcap_path) + send_to_remote(pcap_path, trex_hostname, pcap_remote_path) case TrexMode.ASTF: self.client: TRexAdvancedStateful = manager.request_stateful( @@ -106,6 +111,9 @@ def __init__( self.server: TRexAdvancedStateful = manager.request_stateful( request, role="server" ) + self.trex_version = self.server.get_handler().get_server_version()[ + "version" + ] self.client.set_dst_mac(self.server.get_src_mac()) self.server.set_dst_mac(self.client.get_src_mac()) @@ -115,37 +123,11 @@ def __init__( self.server.set_vlan(target_vlan) case TrexMode.STF: - trex_gen = request.config.getoption("--trex-generator") - trex_host = trex_gen[0].split(",") - trex_hostname = trex_host[0] - trex_pcie = trex_host[1] - self.stf_generator = CTRexClient(trex_hostname) self.trex_version = self.stf_generator.get_trex_version()["Version"] - parent_dir_path = str(self.get_remote_data_path(Path(""))) - parent_dir = executable.Tool( - f"mkdir -p {parent_dir_path} && chmod 777 {parent_dir_path}", - executor=get_trex_executor(request), - sudo=True, - ) - parent_dir.run() - - username: str = request.config.getoption("--user") - remote: str = get_trex_internal(request) - - profile_path = self.get_stf_profile() - profile_remote_path = str(self.get_remote_data_path(profile_path)) - subprocess.run( - [ - "rsync", - "-z", - "--checksum", - "--update", - str(profile_path), - f"{username}@{remote}:{profile_remote_path}", - ] - ) + parent_dir_path = self.get_remote_data_path(Path("")) + mkdir_remote(parent_dir_path, trex_hostname) os.makedirs("tmp", exist_ok=True) config = ConfigBuilder( @@ -161,35 +143,23 @@ def __init__( # similarly this syntax deletes it config.delete_option("[0].port_info.vlan") config = self.stf_config_hook(config) - config_path = config.build() - config_remote_path = str(self.get_remote_data_path(Path(config_path))) + config_path = Path(config.build()) + config_remote_path = self.get_remote_data_path(config_path) self.remote_stf_config = config_remote_path - - subprocess.run( - [ - "rsync", - "-z", - "--checksum", - "--update", - config_path, - f"{username}@{remote}:{config_remote_path}", - ] - ) + send_to_remote(config_path, trex_hostname, config_remote_path) print("Uploading pcaps. This might take a while.") - for pcap, _ in self.pcaps: - pcap_path = self.PCAP_PATH_PREFIX / pcap - pcap_remote_path = str(self.get_remote_data_path(pcap_path)) - subprocess.run( - [ - "rsync", - "-z", - "--checksum", - "--update", - str(pcap_path), - f"{username}@{remote}:{pcap_remote_path}", - ] - ) + for i, pcap in enumerate(self.pcaps): + pcap_path = self.PCAP_PATH_PREFIX / pcap[0] + if target_vlan != 0: + pcap_path = Path(edit_vlan(str(pcap_path), target_vlan)) + self.pcaps[i] = (pcap_path.name, pcap[1]) + pcap_remote_path = self.get_remote_data_path(pcap_path) + send_to_remote(pcap_path, trex_hostname, pcap_remote_path) + + profile_path = self.get_stf_profile() + profile_remote_path = self.get_remote_data_path(profile_path) + send_to_remote(profile_path, trex_hostname, profile_remote_path) def get_remote_data_path(self, local_path: Path) -> Path: """ @@ -197,7 +167,7 @@ def get_remote_data_path(self, local_path: Path) -> Path: A directory is created from the output of `get_remote_data_path(Path(""))`. """ - return self.REMOTE_PCAP_PATH_PREFIX / local_path.name + return Path(f"/opt/trex/{self.trex_version}/pcaps") / local_path.name def get_astf_profile(self, multiplier: float) -> trex_astf_profile.ASTFProfile: """ @@ -244,10 +214,45 @@ def get_stf_profile(self) -> Path: Returns the *local* path to the stateful profile config. The remote path is handled by `get_remote_data_path`. """ - raise NotImplementedError("no default implementation in BaseTrexClientManager") - # it probably is possible to construct a sane default from - # just `self.pcaps`, but this is intended to return a file path - # to an existing `.yaml` file + if self._stf_config_path is not None: + return self._stf_config_path + + self._stf_config_path = Path("tmp/stf_trex_profile.yaml").absolute() + with open(self._stf_config_path, mode="w+") as f: + f.write("[]\n") + profile = ConfigBuilder(str(self._stf_config_path), str(self._stf_config_path)) + profile.add_option("[0].duration", 9999) + profile.add_option( + "[0].generator", + { + "distribution": "seq", + "clients_start": "16.0.0.1", + "clients_end": "16.0.0.255", + "servers_start": "48.0.0.1", + "servers_end": "48.0.255.255", + "clients_per_gb": 200, + "min_clients": 100, + "dual_port_mask": "1.0.0.0", + "tcp_aging": 0, + "udp_aging": 0, + }, + ) + + for i, pcap in enumerate(self.pcaps): + profile.add_option( + f"[0].cap_info.[{i}]", + { + "name": f"pcaps/{pcap[0]}", + "cps": pcap[1], + "ipg": 100, + "rtt": 100, + "w": 1, + }, + ) + + os.makedirs("tmp", exist_ok=True) + profile.build() + return self._stf_config_path def stf_config_hook(self, config: ConfigBuilder) -> ConfigBuilder: """ @@ -314,14 +319,23 @@ def run(self, blocking=True) -> None: pcap_index = 0 while elapsed < self.duration: pcap = self.pcaps[pcap_index] - client.push_remote( - pcap_filename=str(self.get_remote_data_path(Path(pcap[0]))), - ports=[0], - ipg_usec=self.BASE_IPG_USEC / pcap[1], - speedup=self.multiplier, - count=1, - duration=int(self.duration - elapsed), - ) + try: + client.push_remote( + pcap_filename=str(self.get_remote_data_path(Path(pcap[0]))), + ports=[0], + ipg_usec=self.BASE_IPG_USEC / pcap[1], + speedup=self.multiplier, + count=1, + duration=int(self.duration - elapsed), + ) + except TRexError: + # sometimes trex takes a while to stop the previous stream properly + sleep(0.05) + + # intentionally not 100% of the sleep, because with very small + # pcaps this could run for a very long time even with a short duration + start += 0.03 + continue elapsed = time() - start pcap_index = (pcap_index + 1) % len(self.pcaps) @@ -340,9 +354,9 @@ def run(self, blocking=True) -> None: self.stf_generator.start_trex( f=str(self.get_remote_data_path(self.get_stf_profile()).absolute()), - d=self.duration, - m=self.multiplier, - cfg=self.remote_stf_config, + d=str(self.duration), + m=str(self.multiplier), + cfg=str(self.remote_stf_config), ) if blocking: diff --git a/conftest.py b/conftest.py index 1274cbf..7332072 100644 --- a/conftest.py +++ b/conftest.py @@ -127,6 +127,24 @@ def pytest_addoption(parser): action="store", help=("Generate traffic with this VLAN ID. 0 (default) for untagged."), ) + parser.addoption( + "--prefer-trex-mode", + type=str, + default=None, + action="store", + help=( + "Run tests with the specified trex mode if available. If not, fallback to default." + ), + ) + parser.addoption( + "--force-trex-mode", + type=str, + default=None, + action="store", + help=( + "Run tests with the specified trex mode if available. If not, skip test." + ), + ) def get_suri_executor(request) -> remote_executor.Executor: @@ -227,20 +245,6 @@ def return_filename(pcap_filename): return match.group(0) -def send_pcap_to_trex(pcap_filename, request): - - pcaps_dir_trex = executable.Tool( - "mkdir -p /tmp/pcaps/ && chmod 777 /tmp/pcaps/", - executor=get_trex_executor(request), - sudo=True, - ) - pcaps_dir_trex.run() - - os.system( - f"rsync -z --checksum --update {pcap_filename} $(whoami)@{get_trex_internal(request)}.liberouter.org:/tmp/pcaps" - ) - - @pytest.fixture(scope="function") def get_test_name(request): """Function, that returns a name of a current test""" diff --git a/pytest_start.sh b/pytest_start.sh index 0de7888..a25883c 100755 --- a/pytest_start.sh +++ b/pytest_start.sh @@ -20,6 +20,8 @@ usage(){ echo "-ht | --heatup [TIME] to specify the duration for which to wait before measuring statistics" echo "-f | --filter [rules/norules] starts Suricata with/without rules" echo "-pc | --pcap [PATH] to specify the pcap file to send to Suricata. Also sets --defined-tests to *only* pcap_replay" + echo "-pm | --prefer-trex-mode [MODE] to suggest a mode for TRex. If unavailable tests use their defaults." + echo "-fm | --force-trex-mode [MODE] to force a TRex mode. If unavailable tests get skipped. Overrides -pm" exit 0 } @@ -46,6 +48,8 @@ while [ "$#" -gt 0 ]; do *) filter="$2";; esac; shift 2 ;; -pc | --pcap) pcap_replay="$2"; shift 2 ;; + -pm | --prefer-trex-mode) trex_mode_flags+="--prefer-trex-mode $2 " shift 2 ;; + -fm | --force-trex-mode) trex_mode_flags+="--force-trex-mode $2 "; shift 2 ;; -h | --help) usage ; shift ;; --) shift; read -a extra_args <<< "$@"; break ;; *) >&2 echo unsupported option: $1 @@ -203,6 +207,7 @@ do --trex-generator="$trex_server_hostname,$trex_server_port_2" \ --remote-host="$suricata_server" --param-file="param.py" \ --trex-force-use \ + $trex_mode_flags \ --traffic-duration="$defined_time" \ --heatup-duration="$heatup_duration" \ -k "$filter" \ diff --git a/tests/http_https_smb_simple/test_http_https_smb_simple.py b/tests/http_https_smb_simple/test_http_https_smb_simple.py index 33288ea..09138ae 100644 --- a/tests/http_https_smb_simple/test_http_https_smb_simple.py +++ b/tests/http_https_smb_simple/test_http_https_smb_simple.py @@ -21,6 +21,7 @@ HttpHttpsSmbProfile, ) from conftest import kill_pytest, get_trex_multi, suri_interface_bind, Suri_conf +from util.trex_util import TrexMode, get_trex_mode @pytest.mark.parametrize( @@ -71,8 +72,9 @@ def test_http_https_smb( utilized_programs_info=utilized_programs_info, ) + trex_mode = get_trex_mode(request, [TrexMode.ASTF, TrexMode.STF]) trex_client = HttpHttpsSmbProfile( - trex_manager, request, get_target_mac, get_target_vlan + trex_manager, request, get_target_mac, get_target_vlan, mode=trex_mode ) test_variant_name = f"{suri_conf.test_name}_{rules_config['name']}" diff --git a/tests/http_simple/test_http_simple.py b/tests/http_simple/test_http_simple.py index 7d0d654..3bcc710 100644 --- a/tests/http_simple/test_http_simple.py +++ b/tests/http_simple/test_http_simple.py @@ -19,6 +19,7 @@ from util.suri_util import save_stats, TestInfo, RunInfo from assets.trex.traffic_profiles.http_trex_profile.profile import HttpProfile from conftest import kill_pytest, get_trex_multi, suri_interface_bind, Suri_conf +from util.trex_util import TrexMode, get_trex_mode @pytest.mark.parametrize( @@ -69,7 +70,10 @@ def test_http_simple( utilized_programs_info=utilized_programs_info, ) - trex_client = HttpProfile(trex_manager, request, get_target_mac, get_target_vlan) + trex_mode = get_trex_mode(request, [TrexMode.ASTF, TrexMode.STF]) + trex_client = HttpProfile( + trex_manager, request, get_target_mac, get_target_vlan, mode=trex_mode + ) test_variant_name = f"{suri_conf.test_name}_{rules_config['name']}" diff --git a/tests/https_simple/test_https_simple.py b/tests/https_simple/test_https_simple.py index 8f40464..323b6bb 100644 --- a/tests/https_simple/test_https_simple.py +++ b/tests/https_simple/test_https_simple.py @@ -19,6 +19,7 @@ from util.suri_util import save_stats, TestInfo, RunInfo from assets.trex.traffic_profiles.https_trex_profile.profile import HttpsProfile from conftest import kill_pytest, get_trex_multi, suri_interface_bind, Suri_conf +from util.trex_util import TrexMode, get_trex_mode @pytest.mark.parametrize( @@ -69,7 +70,10 @@ def test_https_simple( utilized_programs_info=utilized_programs_info, ) - trex_client = HttpsProfile(trex_manager, request, get_target_mac, get_target_vlan) + trex_mode = get_trex_mode(request, [TrexMode.ASTF, TrexMode.STF]) + trex_client = HttpsProfile( + trex_manager, request, get_target_mac, get_target_vlan, mode=trex_mode + ) test_variant_name = f"{suri_conf.test_name}_{rules_config['name']}" trex_multipliers: List[float] = get_trex_multi( diff --git a/tests/nfs_smb_simple/test_nfs_smb_simple.py b/tests/nfs_smb_simple/test_nfs_smb_simple.py index e28ff66..c038afc 100644 --- a/tests/nfs_smb_simple/test_nfs_smb_simple.py +++ b/tests/nfs_smb_simple/test_nfs_smb_simple.py @@ -19,6 +19,7 @@ from util.suri_util import save_stats, TestInfo, RunInfo from assets.trex.traffic_profiles.nfs_smb_trex_profile.profile import NfsSmbProfile from conftest import kill_pytest, get_trex_multi, suri_interface_bind, Suri_conf +from util.trex_util import TrexMode, get_trex_mode @pytest.mark.parametrize( @@ -69,7 +70,10 @@ def test_nfs_smb( utilized_programs_info=utilized_programs_info, ) - trex_client = NfsSmbProfile(trex_manager, request, get_target_mac, get_target_vlan) + trex_mode = get_trex_mode(request, [TrexMode.ASTF, TrexMode.STF]) + trex_client = NfsSmbProfile( + trex_manager, request, get_target_mac, get_target_vlan, mode=trex_mode + ) test_variant_name = f"{suri_conf.test_name}_{rules_config['name']}" trex_multipliers: List[float] = get_trex_multi( diff --git a/tests/pcap_replay/test_pcap_replay.py b/tests/pcap_replay/test_pcap_replay.py index 506f47c..8301338 100644 --- a/tests/pcap_replay/test_pcap_replay.py +++ b/tests/pcap_replay/test_pcap_replay.py @@ -10,6 +10,8 @@ pytest --trex-generator="trex,0000:65:00.0" --remote-host="claret,0000:3b:00.0" -s --log-level=info """ +from pathlib import Path + import pytest import signal @@ -18,13 +20,13 @@ from util.add_vlan import edit_vlan from util.suricata_manager import Suricata_manager, SuriDown from util.suri_util import save_stats, TestInfo, RunInfo +from util.trex_util import mkdir_remote, send_to_remote from conftest import ( + get_trex_internal, kill_pytest, get_trex_multi, suri_interface_bind, Suri_conf, - send_pcap_to_trex, - return_filename, ) @@ -86,8 +88,10 @@ def test_pcap_replay( get_settings_file, suri_conf.server, suri_conf.pcie, test_variant_name ) - pcap_filename = edit_vlan(get_path_to_pcap, get_target_vlan) - send_pcap_to_trex(pcap_filename, request) + trex_hostname = get_trex_internal(request) + pcap_path = Path(edit_vlan(get_path_to_pcap, get_target_vlan)) + mkdir_remote(Path("/tmp/pcaps"), trex_hostname) + send_to_remote(pcap_path, trex_hostname, Path("/tmp/pcaps") / pcap_path.name) for idx, multiplier in enumerate(trex_multipliers, 1): run_info = RunInfo(multiplier=multiplier) @@ -105,7 +109,7 @@ def test_pcap_replay( pytest.fail("Suricata is down.") traffic_generator.get_handler().push_remote( - pcap_filename=f"/tmp/pcaps/{return_filename(pcap_filename)}", + pcap_filename=f"/tmp/pcaps/{pcap_path.name}", ports=[0], ipg_usec=100, speedup=200 * run_info.multiplier, diff --git a/tests/web_50_sites/test_web_50_sites.py b/tests/web_50_sites/test_web_50_sites.py index 3f1c328..ff66911 100644 --- a/tests/web_50_sites/test_web_50_sites.py +++ b/tests/web_50_sites/test_web_50_sites.py @@ -11,11 +11,11 @@ from typing import List from lbr_testsuite import trex -from assets.trex.traffic_profiles.trex_client_manager import TrexMode from assets.trex.traffic_profiles.web_50_sites_trex_profile import Web50SitesProfile from util.suricata_manager import Suricata_manager, SuriDown from util.suri_util import save_stats, TestInfo, RunInfo from conftest import kill_pytest, get_trex_multi, suri_interface_bind, Suri_conf +from util.trex_util import TrexMode, get_trex_mode @pytest.mark.parametrize( @@ -66,8 +66,9 @@ def test_web_50_sites( utilized_programs_info=utilized_programs_info, ) + trex_mode = get_trex_mode(request, [TrexMode.STF]) trex_client = Web50SitesProfile( - trex_manager, request, get_target_mac, get_target_vlan, mode=TrexMode.STF + trex_manager, request, get_target_mac, get_target_vlan, mode=trex_mode ) test_variant_name = f"{suri_conf.test_name}_{rules_config['name']}" diff --git a/util/config_builder.py b/util/config_builder.py index 971a03d..239e03b 100644 --- a/util/config_builder.py +++ b/util/config_builder.py @@ -5,7 +5,7 @@ from ruamel.yaml import YAML from yamlpath import Processor from yamlpath.enums.yamlvalueformats import YAMLValueFormats -from yamlpath.wrappers import ConsolePrinter +from yamlpath.wrappers import ConsolePrinter, NodeCoords def update_recursively(destination: Dict, source: Dict, extend_lists=True) -> Dict: @@ -33,7 +33,7 @@ def add_option(self, key: str, value: Any) -> Self: if nc.node is not None: raise ValueError(f"Key '{key}' already exists in the configuration") - self.__proc.set_value(key, value) + self.set_option(key, value) return self @@ -74,6 +74,11 @@ def set_option(self, key: str, value: Any) -> Self: self.delete_option(key) for i, item in enumerate(value): self.set_option(f"{key}[{i}]", item) + elif isinstance(value, Dict): + self.__proc.set_value(key, "dummy-value") + nodes: list[NodeCoords] = list(self.__proc.get_nodes(key)) + for nc in nodes: + nc.parent[nc.parentref] = value elif isinstance(value, str): # force quotes self.__proc.set_value(key, value, value_format=YAMLValueFormats.DQUOTE) @@ -111,8 +116,6 @@ def __init__(self, output: str, input: str | None = None) -> None: self.__yaml = YAML() self.__yaml.indent(sequence=4, offset=2) self.__yaml.preserve_quotes = True - self.__yaml.explicit_start = True - self.__yaml.explicit_end = True if input is not None: with open(input, mode="r") as f: diff --git a/util/trex_util.py b/util/trex_util.py new file mode 100644 index 0000000..8d4dc95 --- /dev/null +++ b/util/trex_util.py @@ -0,0 +1,96 @@ +import os +import subprocess +from enum import Enum +from pathlib import Path +from typing import Sequence, Tuple + +import pytest +from lbr_testsuite import executable + + +class TrexMode(Enum): + STL = (0,) + ASTF = (1,) + STF = 2 + + +PcapList = Sequence[Tuple[str, int | float]] + + +def send_to_remote(source: Path, hostname: str, destination: Path | None = None): + if destination is None: + destination = source + + subprocess.run( + [ + "rsync", + "-z", + "--checksum", + "--update", + str(source), + f"{os.environ['USER']}@{hostname}:{str(destination)}", + ] + ) + + +def mkdir_remote(dir: Path, hostname: str): + executor = executable.RemoteExecutor(host=hostname, user=os.environ["USER"]) + mkdir = executable.Tool( + f"mkdir -p {str(dir)} && chmod 777 {str(dir)}", + executor=executor, + sudo=True, + ) + mkdir.run() + + +def str_to_trex_mode(mode: str) -> TrexMode | None: + match mode.lower(): + case "astf": + return TrexMode.ASTF + case "stf": + return TrexMode.STF + case "stl": + return TrexMode.STL + case _: + return None + + +def get_trex_mode(request, available_modes) -> TrexMode: + """ + Selects a TRex mode out of `available_modes` based on the + `--prefer-trex-mode` and `--force-trex-mode` flags. + + `available_modes: List[TrexMode]` should be in descending order by priority. + + Automatically skips tests with no usable TRex modes. + """ + if ( + not isinstance(available_modes, list) + or len(available_modes) < 1 + or not isinstance(available_modes[0], TrexMode) + ): + raise ValueError("available_modes must be a list of at least one TrexMode") + + forced_mode = request.config.getoption("--force-trex-mode") + if forced_mode is not None: + mode_enum = str_to_trex_mode(forced_mode) + if mode_enum is None: + raise ValueError(f"{forced_mode} is not a valid TRex mode") + + if mode_enum in available_modes: + return mode_enum + else: + pytest.skip(f"{forced_mode} is not supported by this test") + + preferred_mode = request.config.getoption("--prefer-trex-mode") + if preferred_mode is not None: + mode_enum = str_to_trex_mode(preferred_mode) + if mode_enum is None: + raise ValueError(f"{forced_mode} is not a valid TRex mode") + + if mode_enum in available_modes: + return mode_enum + else: + return available_modes[0] + + return available_modes[0]