diff --git a/.github/workflows/test-notebooks.yml b/.github/workflows/test-notebooks.yml index dfe025d2..2a5cec6e 100644 --- a/.github/workflows/test-notebooks.yml +++ b/.github/workflows/test-notebooks.yml @@ -40,8 +40,8 @@ jobs: # Skip notebooks that require credentials or special setup case "$name" in - solve-on-oetc.ipynb|solve-on-remote.ipynb) - echo "Skipping $name (requires credentials or special setup)" + remote-machines.ipynb) + echo "Skipping $name (requires credentials or remote machine)" continue ;; esac diff --git a/doc/api.rst b/doc/api.rst index f0afc322..3c59ef09 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -519,10 +519,19 @@ Solvers Remote solving ============== +Solve a model on a remote machine via SSH or on the OET Cloud (OETC). +See :doc:`remote-machines` for usage. + .. autosummary:: :toctree: generated/ + remote.SSH + remote.SshSettings + remote.Oetc + remote.OetcSettings remote.RemoteHandler + remote.OetcHandler + remote.OetcCredentials Solver status and result types diff --git a/doc/index.rst b/doc/index.rst index 39846607..a31d645a 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -136,8 +136,7 @@ This package is published under MIT license. :maxdepth: 2 :caption: Solving - solve-on-remote - solve-on-oetc + remote-machines gpu-acceleration .. toctree:: diff --git a/doc/release_notes.rst b/doc/release_notes.rst index e5b7033f..23a8d91f 100644 --- a/doc/release_notes.rst +++ b/doc/release_notes.rst @@ -46,9 +46,40 @@ Most users should keep calling ``model.solve(...)``. If you want more control, y * Xpress now supports ``io_api="direct"``: the linopy model is loaded via the native ``loadproblem`` array API instead of being serialised through an LP/MPS file, with SOS constraints attached in-place. Adds ``model.to_xpress()`` matching the existing ``to_gurobipy`` / ``to_highspy`` / ``to_mosek`` helpers. * Writing the solution back to the model after solving is faster: it no longer rebuilds the constraint matrix, and now uses positional (rather than label-based) indexing — roughly 2× faster overall. +*Remote solves* + +* Pass ``remote=`` to ``Model.solve`` to run the solver on a remote worker: + + .. code-block:: python + + m.solve("gurobi", remote=OetcSettings(...), Method=2) + m.solve("highs", remote=SshSettings(hostname=...), presolve="on") + + ``solver_name`` and ``**solver_options`` work the same as for local solves; ``remote=`` selects *where* to run. After the call, ``model.remote`` holds the remote instance (mirrors :attr:`Model.solver`). +* ``SshSettings.setup_commands: list[str]`` — shell commands run on the remote before the solve, e.g. ``setup_commands=["conda activate linopy-env"]``. + **Deprecations** * ``Solver.solve_problem``, ``Solver.solve_problem_from_model``, and ``Solver.solve_problem_from_file`` still work but emit a ``DeprecationWarning``. Use ``Solver.from_name(...).solve()`` (or simply ``model.solve(...)``) instead. They will be removed in a future release. +* ``linopy.remote.OetcHandler`` and ``linopy.remote.RemoteHandler`` are deprecated. Construction emits a ``DeprecationWarning``; the ``solve_on_oetc`` / ``solve_on_remote`` return contracts are unchanged. Migrate: + + .. code-block:: python + + # Before + handler = OetcHandler( + OetcSettings(credentials=OetcCredentials(email=..., password=...), ...) + ) + solved = handler.solve_on_oetc(m, TimeLimit=100) + + # After + m.solve( + "gurobi", remote=OetcSettings(email=..., password=..., ...), TimeLimit=100 + ) + + Passing an existing handler via ``Model.solve(remote=handler, ...)`` is also deprecated — pass the settings dataclass instead. +* ``linopy.remote.OetcCredentials`` is deprecated. Pass ``email`` and ``password`` directly to :class:`OetcSettings` instead of wrapping them. The ``OetcSettings(credentials=OetcCredentials(...))`` shape still works for one deprecation cycle and emits a ``DeprecationWarning``. +* ``OetcSettings.solver`` and ``OetcSettings.solver_options`` are deprecated; pass the solver name and options to ``Model.solve(solver_name, remote=..., **options)`` or ``Oetc.submit(model, solver_name, **options)`` instead. During deprecation they are still honored — as a fallback when ``Model.solve(remote=...)`` is called without a ``solver_name``, and by the deprecated ``OetcHandler`` — and will be removed in a future release. +* :class:`linopy.remote.SSH` only exposes ``solve(...)``. For env activation use ``SshSettings.setup_commands``; for arbitrary remote shell commands, drop to :class:`RemoteHandler` (during deprecation) or paramiko directly. **Bug Fixes** @@ -60,6 +91,7 @@ Most users should keep calling ``model.solve(...)``. If you want more control, y * ``available_solvers`` now lists all *installed* solvers, even ones without a working license. If you used it to decide "can I actually solve with X?", switch to ``linopy.licensed_solvers`` or ``SolverClass.license_status()``. * ``Model.solver_model`` and ``Model.solver_name`` are now read-only properties that delegate to ``model.solver``. You can't reassign them (only ``= None`` is allowed, which closes the solver), and ``solver_name`` is ``None`` before the first solve. * ``result.solution.primal`` and ``result.solution.dual`` are now ``numpy`` arrays indexed by linopy's integer labels (with ``NaN`` for slots without a value), instead of pandas Series keyed by variable/constraint name. If you accessed them by name, use ``model.variables[name].solution`` (or ``model.constraints[name].dual``) instead. +* The pip extra ``linopy[remote]`` has been renamed to ``linopy[ssh]`` to match what it installs (only ``paramiko``, for SSH transport — OETC has its own ``linopy[oetc]`` extra). ``linopy[remote]`` no longer exists; update your install commands. **Internal** diff --git a/doc/remote-machines.nblink b/doc/remote-machines.nblink new file mode 100644 index 00000000..f273fb0c --- /dev/null +++ b/doc/remote-machines.nblink @@ -0,0 +1,3 @@ +{ + "path": "../examples/remote-machines.ipynb" +} diff --git a/doc/solve-on-oetc.nblink b/doc/solve-on-oetc.nblink deleted file mode 100644 index ab7ed00c..00000000 --- a/doc/solve-on-oetc.nblink +++ /dev/null @@ -1,3 +0,0 @@ -{ - "path": "../examples/solve-on-oetc.ipynb" -} diff --git a/doc/solve-on-remote.nblink b/doc/solve-on-remote.nblink deleted file mode 100644 index 03be52c0..00000000 --- a/doc/solve-on-remote.nblink +++ /dev/null @@ -1,3 +0,0 @@ -{ - "path": "../examples/solve-on-remote.ipynb" -} diff --git a/doc/user-guide.rst b/doc/user-guide.rst index 8b7ee5bd..ce4549c3 100644 --- a/doc/user-guide.rst +++ b/doc/user-guide.rst @@ -53,8 +53,8 @@ Where to go next :doc:`piecewise-linear-constraints`, and the :doc:`testing-framework` for asserting structural properties of a model. -- **Solving** — :doc:`solve-on-remote` (SSH), - :doc:`solve-on-oetc` (OET Cloud), :doc:`gpu-acceleration` (cuPDLPx). +- **Solving** — :doc:`remote-machines` (SSH or OET Cloud), + :doc:`gpu-acceleration` (cuPDLPx). - **Troubleshooting** — :doc:`infeasible-model` (diagnosing infeasible problems), :doc:`gurobi-double-logging` (and other solver quirks). - **Reference** — the full :doc:`api` listing. diff --git a/examples/remote-machines.ipynb b/examples/remote-machines.ipynb new file mode 100644 index 00000000..0c8fa0c8 --- /dev/null +++ b/examples/remote-machines.ipynb @@ -0,0 +1,259 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Remote machines\n", + "\n", + "linopy can ship your model to a remote machine, run a solver there, and pull the solved model back. Two remotes are supported, differing in *who runs the machine*:\n", + "\n", + "- **SSH** — a server you run yourself, reached over SSH.\n", + "- **OETC** — [OET Cloud](https://open-energy-transition.org/), a managed optimization service that runs the machine for you.\n", + "\n", + "Both share the same entry point on `Model.solve`:\n", + "\n", + "```python\n", + "m.solve(\"gurobi\", remote=, **solver_options)\n", + "```\n", + "\n", + "`solver_name` and `**solver_options` work exactly like a local solve; `remote=` selects *where* to run. After the call, `model.remote` holds the remote instance for post-solve introspection (mirrors `model.solver`)." + ] + }, + { + "cell_type": "markdown", + "id": "1", + "metadata": {}, + "source": [ + "> **Note:** This notebook is not executed during the documentation build — it requires either SSH access to a remote server or OETC credentials." + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "## Create a model\n", + "\n", + "Build the model locally as usual:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "from numpy import arange\n", + "from xarray import DataArray\n", + "\n", + "from linopy import Model\n", + "\n", + "N = 10\n", + "m = Model()\n", + "coords = [arange(N), arange(N)]\n", + "x = m.add_variables(coords=coords, name=\"x\")\n", + "y = m.add_variables(coords=coords, name=\"y\")\n", + "m.add_constraints(x - y >= DataArray(arange(N)))\n", + "m.add_constraints(x + y >= 0)\n", + "m.add_objective((2 * x + y).sum())\n", + "m" + ] + }, + { + "cell_type": "markdown", + "id": "4", + "metadata": {}, + "source": [ + "## SSH\n", + "\n", + "**What you need**\n", + "\n", + "- `uv pip install \"linopy[ssh]\"` locally (pulls in `paramiko`).\n", + "- A remote server with linopy and a solver installed (e.g. in a conda environment).\n", + "- SSH access to that machine (key-based auth recommended).\n", + "\n", + "Build an `SshSettings` and pass it as `remote=`. Use `setup_commands` to activate environments or export variables on the remote shell before the solve." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "from linopy.remote import SshSettings\n", + "\n", + "ssh_settings = SshSettings(\n", + " hostname=\"your.host.de\",\n", + " username=\"username\",\n", + " # password=\"...\", # not needed when SSH keys are autodetected\n", + " setup_commands=[\"conda activate linopy-env\"],\n", + ")\n", + "\n", + "m.solve(\"gurobi\", remote=ssh_settings)\n", + "m.solution" + ] + }, + { + "cell_type": "markdown", + "id": "6", + "metadata": {}, + "source": [ + "## OETC\n", + "\n", + "**What you need**\n", + "\n", + "- `uv pip install \"linopy[oetc]\"` locally (pulls in `google-cloud-storage` and `requests`).\n", + "- An OETC account with valid credentials.\n", + "- The OETC authentication and orchestrator server URLs.\n", + "\n", + "Build an `OetcSettings`. Two construction styles:\n", + "\n", + "1. **Manually** — pass `email`, `password`, `name`, and the server URLs.\n", + "2. **`OetcSettings.from_env()`** — resolve everything from environment variables (`OETC_EMAIL`, `OETC_PASSWORD`, `OETC_NAME`, `OETC_AUTH_URL`, `OETC_ORCHESTRATOR_URL`). Recommended for CI/CD. Keyword arguments override the environment.\n", + "\n", + "linopy uploads the model to OETC, submits a compute job, polls until it finishes, and downloads the solution — all behind one call." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "from linopy.remote import OetcSettings\n", + "\n", + "# Option 1: pass credentials directly\n", + "oetc_settings = OetcSettings(\n", + " email=\"your-email@example.com\",\n", + " password=\"your-password\",\n", + " name=\"linopy-example-job\",\n", + " authentication_server_url=\"https://auth.oetcloud.com\",\n", + " orchestrator_server_url=\"https://orchestrator.oetcloud.com\",\n", + " cpu_cores=4,\n", + " disk_space_gb=20,\n", + ")\n", + "\n", + "# Option 2: load from environment (with optional overrides)\n", + "# oetc_settings = OetcSettings.from_env(cpu_cores=4, disk_space_gb=20)\n", + "\n", + "m.solve(\"gurobi\", remote=oetc_settings, TimeLimit=600, MIPGap=0.01)\n", + "\n", + "print(f\"Status: {m.status}\")\n", + "print(f\"Objective: {m.objective.value:.4f}\")\n", + "m.solution" + ] + }, + { + "cell_type": "markdown", + "id": "8", + "metadata": {}, + "source": [ + "## Advanced: remote objects\n", + "\n", + "`Model.solve(remote=...)` builds a throwaway `Oetc` / `SSH` object for each call. You can also create one yourself and keep it — both are **reusable remote objects**: set one up once, then solve any number of models. `SSH` holds one SSH connection open; `Oetc` authenticates once and reuses the token. OETC additionally exposes an async `submit` / `status` / `collect` seam." + ] + }, + { + "cell_type": "markdown", + "id": "9", + "metadata": {}, + "source": [ + ".. important::\n", + " ``Oetc`` and ``SSH`` return a *new* solved ``Model`` — they never modify the model you pass in. Use the returned object, or call ``Model.solve(remote=...)`` for in-place writeback." + ] + }, + { + "cell_type": "markdown", + "id": "10", + "metadata": {}, + "source": [ + "### SSH\n", + "\n", + "`SSH(settings)` holds the SSH connection open. Reuse one instance to solve several models without reconnecting or re-running `setup_commands` each time:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "from linopy.remote import SSH\n", + "\n", + "# One SSH object keeps one connection open. Reuse it across models\n", + "# so setup_commands and the SSH handshake run only once.\n", + "ssh = SSH(ssh_settings)\n", + "\n", + "models = [m] # replace with your own models\n", + "solved = [ssh.solve(model, \"gurobi\", presolve=\"on\") for model in models]\n", + "\n", + "# `m` is untouched — the solution is on the returned model.\n", + "solved[0].solution" + ] + }, + { + "cell_type": "markdown", + "id": "12", + "metadata": {}, + "source": [ + "### OETC\n", + "\n", + "`Oetc` adds an async seam on top of session reuse:\n", + "\n", + "1. `submit(model, solver_name, **options)` — upload and dispatch; returns a job uuid.\n", + "2. `status(job_uuid)` — a single, non-blocking status check.\n", + "3. `collect(job_uuid)` — wait for completion, download, and return the solved model.\n", + "\n", + "A job is identified solely by its uuid string: submit many models, hold their uuids, and collect each when convenient — even from a different process, where a fresh `Oetc(settings)` re-authenticates and collects by uuid." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "from concurrent.futures import ThreadPoolExecutor\n", + "\n", + "from linopy.remote import Oetc\n", + "\n", + "oetc = Oetc(oetc_settings)\n", + "\n", + "# A single job: submit now, collect by uuid later (even in another process).\n", + "job_uuid = oetc.submit(m, \"gurobi\", TimeLimit=600, MIPGap=0.01)\n", + "print(f\"Submitted job {job_uuid}\")\n", + "solved = oetc.collect(job_uuid)\n", + "print(f\"Status: {solved.status}\")\n", + "\n", + "# Many models on one session: submit all, then collect concurrently.\n", + "models = [m] # replace with your own models\n", + "uuids = [oetc.submit(model, \"gurobi\") for model in models]\n", + "with ThreadPoolExecutor() as pool:\n", + " solved_models = list(pool.map(oetc.collect, uuids))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python" + }, + "nbsphinx": { + "execute": "never" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/solve-on-oetc.ipynb b/examples/solve-on-oetc.ipynb deleted file mode 100644 index 28e1c04d..00000000 --- a/examples/solve-on-oetc.ipynb +++ /dev/null @@ -1,431 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Solve on OETC (OET Cloud)\n", - "\n", - "This example demonstrates how to use linopy with OETC (OET Cloud) for cloud-based optimization solving. OETC is a cloud platform that provides scalable computing resources for optimization problems.\n", - "\n", - "## What you need to run this example:\n", - "\n", - "* A working installation of the required packages:\n", - " * `pip install google-cloud-storage requests`\n", - "* An OETC account with valid credentials (email and password)\n", - "* Access to OETC authentication and orchestrator servers\n", - "\n", - "## How OETC Cloud Solving Works\n", - "\n", - "The OETC integration follows this workflow:\n", - "\n", - "1. **Model Creation**: Define your optimization model locally using linopy\n", - "2. **Authentication**: Sign in to the OETC platform using your credentials\n", - "3. **File Upload**: Compress and upload your model to Google Cloud Storage\n", - "4. **Job Submission**: Submit a compute job to the OETC orchestrator\n", - "5. **Job Monitoring**: Wait for job completion with automatic status polling\n", - "6. **Solution Download**: Download and decompress the solved model\n", - "7. **Local Integration**: Load the solution back into your local model\n", - "\n", - "All of these steps are handled automatically by linopy's `OetcHandler`." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "> **Note:** This notebook requires Google Cloud credentials and access to the OETC platform. It is not executed during the documentation build, so no cell outputs are shown. To run it yourself, install the `linopy[oetc]` extra and configure your credentials." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Create a Model\n", - "\n", - "First, let's create an optimization model that we want to solve on OETC:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from numpy import arange\n", - "from xarray import DataArray\n", - "\n", - "from linopy import Model\n", - "\n", - "# Create a medium-sized optimization problem\n", - "N = 50\n", - "m = Model()\n", - "\n", - "# Define decision variables with coordinates\n", - "coords = [arange(N), arange(N)]\n", - "x = m.add_variables(coords=coords, name=\"x\", lower=0)\n", - "y = m.add_variables(coords=coords, name=\"y\", lower=0)\n", - "\n", - "# Add constraints\n", - "m.add_constraints(x - y >= DataArray(arange(N)), name=\"constraint1\")\n", - "m.add_constraints(x + y >= DataArray(arange(N) * 0.5), name=\"constraint2\")\n", - "m.add_constraints(x <= DataArray(arange(N) + 10), name=\"upper_bounds\")\n", - "\n", - "# Set objective function\n", - "m.add_objective((2 * x + y).sum())\n", - "\n", - "print(\n", - " f\"Model created with {len(m.variables)} variable groups and {len(m.constraints)} constraint groups\"\n", - ")\n", - "m" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Configure OETC Settings\n", - "\n", - "There are two ways to configure OETC settings:\n", - "\n", - "1. **Manual construction** — build `OetcCredentials` and `OetcSettings` explicitly\n", - "2. **`OetcSettings.from_env()`** — resolve credentials and options from environment variables\n", - "\n", - "### Option 1: Manual Construction" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Configure your OETC credentials\n", - "# IMPORTANT: Never hardcode credentials in production code!\n", - "# Use environment variables or secure credential management\n", - "import os\n", - "\n", - "from linopy.remote.oetc import (\n", - " ComputeProvider,\n", - " OetcCredentials,\n", - " OetcHandler,\n", - " OetcSettings,\n", - ")\n", - "\n", - "credentials = OetcCredentials(\n", - " email=os.getenv(\"OETC_EMAIL\", \"your-email@example.com\"),\n", - " password=os.getenv(\"OETC_PASSWORD\", \"your-password\"),\n", - ")\n", - "\n", - "# Configure OETC settings\n", - "settings = OetcSettings(\n", - " credentials=credentials,\n", - " name=\"linopy-example-job\",\n", - " authentication_server_url=\"https://auth.oetcloud.com\", # Replace with actual URL\n", - " orchestrator_server_url=\"https://orchestrator.oetcloud.com\", # Replace with actual URL\n", - " compute_provider=ComputeProvider.GCP,\n", - " cpu_cores=4, # Number of CPU cores to allocate\n", - " disk_space_gb=20, # Disk space in GB\n", - " delete_worker_on_error=False, # Keep worker for debugging if job fails\n", - ")\n", - "\n", - "print(\"OETC settings configured successfully\")\n", - "print(f\"Solver: {settings.solver}\")\n", - "print(f\"CPU cores: {settings.cpu_cores}\")\n", - "print(f\"Disk space: {settings.disk_space_gb} GB\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Option 2: Create Settings from Environment Variables\n", - "\n", - "`OetcSettings.from_env()` reads configuration from environment variables,\n", - "with optional keyword overrides. This is the recommended approach for\n", - "CI/CD pipelines and production deployments.\n", - "\n", - "| Environment Variable | Required | Description |\n", - "|---|---|---|\n", - "| `OETC_EMAIL` | Yes | Account email |\n", - "| `OETC_PASSWORD` | Yes | Account password |\n", - "| `OETC_NAME` | Yes | Job name |\n", - "| `OETC_AUTH_URL` | Yes | Authentication server URL |\n", - "| `OETC_ORCHESTRATOR_URL` | Yes | Orchestrator server URL |\n", - "| `OETC_CPU_CORES` | No | CPU cores (default: 2) |\n", - "| `OETC_DISK_SPACE_GB` | No | Disk space in GB (default: 10) |\n", - "| `OETC_DELETE_WORKER_ON_ERROR` | No | Delete worker on error (default: false) |\n", - "\n", - "Keyword arguments take precedence over environment variables." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Create settings from environment variables\n", - "# All required env vars must be set: OETC_EMAIL, OETC_PASSWORD,\n", - "# OETC_NAME, OETC_AUTH_URL, OETC_ORCHESTRATOR_URL\n", - "settings = OetcSettings.from_env()\n", - "\n", - "# Or override specific values via keyword arguments\n", - "settings = OetcSettings.from_env(\n", - " cpu_cores=8,\n", - " disk_space_gb=50,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Initialize OETC Handler\n", - "\n", - "The `OetcHandler` manages the entire cloud solving process:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Initialize the OETC handler\n", - "# This will authenticate with OETC and fetch cloud provider credentials\n", - "oetc_handler = OetcHandler(settings)\n", - "\n", - "print(\"OETC handler initialized successfully\")\n", - "print(f\"Authentication token expires at: {oetc_handler.jwt.expires_at}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Solve the Model on OETC\n", - "\n", - "Now we can solve our model on the OETC cloud platform. The `OetcHandler` is passed to the model's `solve()` method:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Solve the model on OETC\n", - "# This will upload the model, submit a job, wait for completion, and download the solution\n", - "import time\n", - "\n", - "print(\"Starting cloud solving process...\")\n", - "start_time = time.time()\n", - "\n", - "try:\n", - " status, termination_condition = m.solve(remote=oetc_handler, solver_name=\"highs\")\n", - "\n", - " end_time = time.time()\n", - " total_time = end_time - start_time\n", - "\n", - " print(f\"\\nSolving completed in {total_time:.2f} seconds\")\n", - " print(f\"Status: {status}\")\n", - " print(f\"Termination condition: {termination_condition}\")\n", - " print(f\"Objective value: {m.objective.value:.4f}\")\n", - "\n", - "except Exception as e:\n", - " print(f\"Error during solving: {e}\")\n", - " raise" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Examine the Solution\n", - "\n", - "Let's examine the solution returned from OETC:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Display solution summary\n", - "print(f\"Model status: {m.status}\")\n", - "print(f\"Objective value: {m.objective.value}\")\n", - "print(f\"Number of variables: {m.solution.sizes}\")\n", - "\n", - "# Show a subset of the solution\n", - "print(\"\\nSample of solution values:\")\n", - "print(\"x values (first 5x5):\")\n", - "print(m.solution[\"x\"].isel(dim_0=slice(0, 5), dim_1=slice(0, 5)).values)\n", - "\n", - "print(\"\\ny values (first 5x5):\")\n", - "print(m.solution[\"y\"].isel(dim_0=slice(0, 5), dim_1=slice(0, 5)).values)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Advanced OETC Configuration\n", - "\n", - "### Solver Options\n", - "\n", - "Solver name and options can be configured at two levels:\n", - "\n", - "1. **Settings level** — defaults stored in `OetcSettings.solver` and `OetcSettings.solver_options`\n", - "2. **Call level** — passed via `m.solve(solver_name=..., **solver_options)`\n", - "\n", - "Call-level options **override** settings-level options. The two dicts are\n", - "merged (call-time takes precedence), and the original settings are never\n", - "mutated." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Settings-level defaults\n", - "advanced_settings = OetcSettings(\n", - " credentials=credentials,\n", - " name=\"advanced-linopy-job\",\n", - " authentication_server_url=\"https://auth.oetcloud.com\",\n", - " orchestrator_server_url=\"https://orchestrator.oetcloud.com\",\n", - " solver=\"gurobi\",\n", - " solver_options={\n", - " \"TimeLimit\": 600,\n", - " \"MIPGap\": 0.01,\n", - " },\n", - " cpu_cores=8,\n", - " disk_space_gb=50,\n", - ")\n", - "\n", - "advanced_handler = OetcHandler(advanced_settings)\n", - "\n", - "# Call-level overrides: solver_name and solver_options are forwarded\n", - "# to OETC and merged with the settings defaults.\n", - "# Here MIPGap from settings (0.01) is kept, TimeLimit is overridden to 300.\n", - "status, condition = m.solve(\n", - " remote=advanced_handler,\n", - " solver_name=\"gurobi\",\n", - " TimeLimit=300,\n", - " Threads=4,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Error Handling and Debugging\n", - "\n", - "When working with cloud solving, it's important to handle potential errors gracefully:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def solve_with_error_handling(model, oetc_handler, max_retries=3):\n", - " \"\"\"Solve model with error handling and retries\"\"\"\n", - "\n", - " for attempt in range(max_retries):\n", - " try:\n", - " print(f\"Solving attempt {attempt + 1}/{max_retries}...\")\n", - " status, termination = model.solve(remote=oetc_handler)\n", - "\n", - " if status == \"ok\":\n", - " print(\"Solving successful!\")\n", - " return status, termination\n", - " else:\n", - " print(f\"Solving returned status: {status}\")\n", - "\n", - " except Exception as e:\n", - " print(f\"Attempt {attempt + 1} failed: {e}\")\n", - "\n", - " if attempt < max_retries - 1:\n", - " print(\"Retrying in 30 seconds...\")\n", - " time.sleep(30)\n", - " else:\n", - " print(\"All attempts failed\")\n", - " raise\n", - "\n", - " return None, None\n", - "\n", - "\n", - "# Example usage (commented out to avoid actual execution)\n", - "# status, termination = solve_with_error_handling(m, oetc_handler)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Security Best Practices\n", - "\n", - "When using OETC in production:\n", - "\n", - "1. **Never hardcode credentials**: Use environment variables or secure credential stores\n", - "2. **Use token expiration**: The OETC handler automatically manages token expiration\n", - "3. **Validate inputs**: Ensure your model data doesn't contain sensitive information\n", - "4. **Monitor costs**: Cloud computing resources have associated costs\n", - "5. **Clean up resources**: Set `delete_worker_on_error=True` for automatic cleanup\n", - "\n", - "## Comparison with SSH Remote Solving\n", - "\n", - "| Feature | OETC Cloud | SSH Remote |\n", - "|---------|------------|------------|\n", - "| Setup | Account registration | Server access required |\n", - "| Scalability | Auto-scaling | Fixed server resources |\n", - "| Maintenance | Managed service | Self-managed |\n", - "| Cost | Pay-per-use | Infrastructure costs |\n", - "| Security | Enterprise-grade | Self-managed |\n", - "| Solver Licenses | Included | User-provided |\n", - "\n", - "Choose OETC for:\n", - "- Large-scale problems requiring significant compute resources\n", - "- Temporary or intermittent optimization needs\n", - "- Teams without dedicated infrastructure\n", - "- Access to premium solvers without license management\n", - "\n", - "Choose SSH remote for:\n", - "- Existing infrastructure with optimization solvers\n", - "- Strict data governance requirements\n", - "- Consistent, long-running optimization workloads\n", - "- Full control over the solving environment" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.3" - }, - "nbsphinx": { - "execute": "never" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/examples/solve-on-remote.ipynb b/examples/solve-on-remote.ipynb deleted file mode 100644 index 73e6346b..00000000 --- a/examples/solve-on-remote.ipynb +++ /dev/null @@ -1,655 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "4db583af", - "metadata": {}, - "source": [ - "# Remote Solving with SSH\n", - "\n", - "This example demonstrates how linopy can solve optimization models on remote machines using SSH connections. This is one of two remote solving options available in linopy:\n", - "\n", - "1. **SSH Remote Solving** (this example) - Connect to your own servers via SSH\n", - "2. **OETC Cloud Solving** - Use cloud-based optimization services (see [OETC notebook](solve-on-oetc.ipynb))\n", - "\n", - "## SSH Remote Solving\n", - "\n", - "SSH remote solving is ideal when you have:\n", - "\n", - "* Access to dedicated servers with optimization solvers installed\n", - "* Full control over the computing environment\n", - "* Existing infrastructure for optimization workloads\n", - "\n", - "## What you need for SSH remote solving\n", - "\n", - "* The `remote` extra installed on your local machine (`uv pip install \"linopy[remote]\"`), which pulls in `paramiko`\n", - "* A remote server with a working installation of linopy (e.g., in a conda environment)\n", - "* SSH access to that machine\n", - "\n", - "## How SSH Remote Solving Works\n", - "\n", - "The workflow consists of the following steps, most of which linopy handles automatically:\n", - "\n", - "1. Define a model on the local machine\n", - "2. Save the model on the remote machine via SSH\n", - "3. Load, solve and write out the model on the remote machine\n", - "4. Copy the solved model back to the local machine\n", - "5. Load the solved model on the local machine\n", - "\n", - "The model initialization happens locally, while the actual solving happens remotely.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "> **Note:** This notebook requires SSH access to a remote server with a solver installed. It is not executed during the documentation build, so no cell outputs are shown. To run it yourself, configure SSH access and install a solver on the remote machine." - ] - }, - { - "cell_type": "markdown", - "id": "together-ocean", - "metadata": {}, - "source": [ - "## Create a model\n", - "\n", - "First we are going to build the optimization model we want to solve in our local process." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "dramatic-cannon", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Linopy LP model\n", - "===============\n", - "\n", - "Variables:\n", - "----------\n", - " * x (dim_0, dim_1)\n", - " * y (dim_0, dim_1)\n", - "\n", - "Constraints:\n", - "------------\n", - " * con0 (dim_0, dim_1)\n", - " * con1 (dim_0, dim_1)\n", - "\n", - "Status:\n", - "-------\n", - "initialized" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from numpy import arange\n", - "from xarray import DataArray\n", - "\n", - "from linopy import Model\n", - "\n", - "N = 10\n", - "m = Model()\n", - "coords = [arange(N), arange(N)]\n", - "x = m.add_variables(coords=coords, name=\"x\")\n", - "y = m.add_variables(coords=coords, name=\"y\")\n", - "m.add_constraints(x - y >= DataArray(arange(N)))\n", - "m.add_constraints(x + y >= 0)\n", - "m.add_objective((2 * x + y).sum())\n", - "m" - ] - }, - { - "cell_type": "markdown", - "id": "0f9e9b09", - "metadata": {}, - "source": [ - "## Initialize SSH connection\n", - "\n", - "Now we have to set up the SSH connection. The SSH connection is handled by the `RemoteHandler` class in of the `linopy.remote` module. This is strongly relying on the `paramiko` package. When initializing, you have two options:\n", - "\n", - "1. Pass the standard arguments `host`, `username`. If the SSH keys are stored in a default location, the keys are autodetected and the `RemoteHandler` does not require the `password` argument. Otherwise you also have to pass the password.\n", - "2. Pass a working `paramiko.SSHClient` as `client`. This enables you to set up the SSH connection by others means supported by `paramiko`. \n", - "\n", - "In the following we use the first option." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "protecting-power", - "metadata": {}, - "outputs": [], - "source": [ - "from linopy import RemoteHandler\n", - "\n", - "host = \"your.host.de\"\n", - "username = \"username\"\n", - "\n", - "handler = RemoteHandler(host, username=username)" - ] - }, - { - "cell_type": "markdown", - "id": "featured-maria", - "metadata": {}, - "source": [ - "## Optionally: Activate a conda environment on the remote \n", - "\n", - "The `RemoteHandler` keeps an interactive shell in the background. You can execute any code in order to prepare the solving process (install linopy, activate an environment). \n", - "\n", - "Assuming you have a conda environment `linopy-env` that contains the `linopy` package with dependencies, you can run " - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "virtual-anxiety", - "metadata": {}, - "outputs": [], - "source": [ - "handler.execute(\"conda activate linopy-env\")" - ] - }, - { - "cell_type": "markdown", - "id": "sonic-rebate", - "metadata": {}, - "source": [ - "## Solve the model on remote\n", - "\n", - "Now the only thing you have to do is to pass the `RemoteHandler` as an argument to the `solve` function. Other keyword arguments like `solver_name` and solver options are propagated to the remote machine. " - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "ongoing-desktop", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Set parameter Username\n", - "Academic license - for non-commercial use only - expires 2023-02-06\n", - "Read LP format model from file /tmp/linopy-problem-uh4gvjyp.lp\n", - "Reading time = 0.00 seconds\n", - "obj: 200 rows, 200 columns, 400 nonzeros\n", - "Gurobi Optimizer version 9.5.1 build v9.5.1rc2 (linux64)\n", - "Thread count: 12 physical cores, 24 logical processors, using up to 24 threads\n", - "Optimize a model with 200 rows, 200 columns and 400 nonzeros\n", - "Model fingerprint: 0xf2bcac49\n", - "Coefficient statistics:\n", - "Matrix range [1e+00, 1e+00]\n", - "Objective range [1e+00, 2e+00]\n", - "Bounds range [0e+00, 0e+00]\n", - "RHS range [1e+00, 9e+00]\n", - "Presolve removed 200 rows and 200 columns\n", - "Presolve time: 0.00s\n", - "Presolve: All rows and columns removed\n", - "Iteration Objective Primal Inf. Dual Inf. Time\n", - "0 2.2500000e+02 0.000000e+00 0.000000e+00 0s\n", - "\n", - "Solved in 0 iterations and 0.00 seconds (0.00 work units)\n", - "Optimal objective 2.250000000e+02\n" - ] - }, - { - "data": { - "text/plain": [ - "('ok', '')" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "m.solve(remote=handler)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "sustained-portrait", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
<xarray.Dataset>\n",
-       "Dimensions:  (dim_0: 10, dim_1: 10)\n",
-       "Coordinates:\n",
-       "  * dim_0    (dim_0) int64 0 1 2 3 4 5 6 7 8 9\n",
-       "  * dim_1    (dim_1) int64 0 1 2 3 4 5 6 7 8 9\n",
-       "Data variables:\n",
-       "    x        (dim_0, dim_1) float64 0.0 0.0 0.0 0.0 0.0 ... 4.5 4.5 4.5 4.5 4.5\n",
-       "    y        (dim_0, dim_1) float64 0.0 0.0 0.0 0.0 0.0 ... -4.5 -4.5 -4.5 -4.5
" - ], - "text/plain": [ - "\n", - "Dimensions: (dim_0: 10, dim_1: 10)\n", - "Coordinates:\n", - " * dim_0 (dim_0) int64 0 1 2 3 4 5 6 7 8 9\n", - " * dim_1 (dim_1) int64 0 1 2 3 4 5 6 7 8 9\n", - "Data variables:\n", - " x (dim_0, dim_1) float64 0.0 0.0 0.0 0.0 0.0 ... 4.5 4.5 4.5 4.5 4.5\n", - " y (dim_0, dim_1) float64 0.0 0.0 0.0 0.0 0.0 ... -4.5 -4.5 -4.5 -4.5" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "m.solution" - ] - } - ], - "metadata": { - "@webio": { - "lastCommId": null, - "lastKernelId": null - }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.3" - }, - "nbsphinx": { - "execute": "never" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/linopy/model.py b/linopy/model.py index 48a8200b..06845272 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -79,11 +79,13 @@ add_piecewise_formulation, ) from linopy.remote import RemoteHandler +from linopy.remote.ssh import SshSettings try: - from linopy.remote import OetcHandler + from linopy.remote import OetcHandler, OetcSettings except ImportError: OetcHandler = None # type: ignore + OetcSettings = None # type: ignore from linopy.solver_capabilities import solver_supports from linopy.solvers import ( IO_APIS, @@ -94,6 +96,7 @@ SOSReformulationResult, reformulate_sos_constraints, sos_reformulation_context, + suppress_serialization_warning, undo_sos_reformulation, ) from linopy.types import ( @@ -111,6 +114,14 @@ logger = logging.getLogger(__name__) +# Types accepted as ``remote=`` for the standalone-class dispatch in +# :meth:`Model.solve` (as opposed to the legacy ``OetcHandler`` / +# ``RemoteHandler`` deprecation path). The OETC entry is conditional on +# the optional google-cloud / requests deps being available. +_REMOTE_SETTINGS_TYPES: tuple[type, ...] = (SshSettings,) +if OetcSettings is not None: + _REMOTE_SETTINGS_TYPES = (*_REMOTE_SETTINGS_TYPES, OetcSettings) + def _coords_to_dict( coords: Sequence[Sequence | pd.Index | DataArray] | Mapping, @@ -196,6 +207,7 @@ class Model: """ _solver: solvers.Solver | None + _remote: Any _variables: Variables _constraints: Constraints _objective: Objective @@ -243,6 +255,7 @@ class Model: "_relaxed_registry", "_piecewise_formulations", "_solver", + "_remote", "_sos_reformulation_state", "__weakref__", ) @@ -314,6 +327,7 @@ def __init__( gettempdir() if solver_dir is None else solver_dir ) self._solver: solvers.Solver | None = None + self._remote: Any = None self._sos_reformulation_state: SOSReformulationResult | None = None @property @@ -326,6 +340,24 @@ def solver(self, value: solvers.Solver | None) -> None: self._solver.close() self._solver = value + @property + def remote(self) -> Any: + """ + Standalone remote instance (``Oetc`` / ``SSH``) from the most recent + solve, or ``None``. + + Set by :meth:`solve` when called with ``remote=``; lets + callers reuse the authenticated connection after the solve (e.g. + ``model.remote.submit(...)`` for further OETC jobs). ``None`` for + local solves and after a legacy ``remote=OetcHandler/RemoteHandler`` + solve. + """ + return self._remote + + @remote.setter + def remote(self, value: Any) -> None: + self._remote = value + @property def solver_model(self) -> Any: return self.solver.solver_model if self.solver is not None else None @@ -1594,7 +1626,7 @@ def solve( sanitize_zeros: bool = True, sanitize_infinities: bool = True, slice_size: int = 2_000_000, - remote: RemoteHandler | OetcHandler | None = None, + remote: RemoteHandler | OetcHandler | OetcSettings | SshSettings | None = None, progress: bool | None = None, mock_solve: bool = False, reformulate_sos: bool | Literal["auto"] = False, @@ -1699,50 +1731,37 @@ def solve( f"Keyword argument `io_api` has to be one of {IO_APIS} or None" ) - if remote is not None: - # The remote branch short-circuits before reaching Solver.solve(), - # which is where the empty-objective check normally fires. Replicate - # it here. This duplication becomes obsolete once OETC is folded - # into the Solver pipeline (see PyPSA/linopy#683). - if self.objective.expression.empty: - raise ValueError( - "No objective has been set on the model. Use " - "`m.add_objective(...)` first (e.g. `m.add_objective(0 * x)` " - "for a pure feasibility problem)." - ) - if isinstance(remote, OetcHandler): - solved = remote.solve_on_oetc( - self, - solver_name=solver_name, - reformulate_sos=reformulate_sos, - **solver_options, - ) - else: - solved = remote.solve_on_remote( - self, - solver_name=solver_name, - io_api=io_api, - problem_fn=problem_fn, - solution_fn=solution_fn, - log_fn=log_fn, - basis_fn=basis_fn, - warmstart_fn=warmstart_fn, - keep_files=keep_files, - sanitize_zeros=sanitize_zeros, - reformulate_sos=reformulate_sos, - **solver_options, - ) + # New standalone Oetc / SSH remote handlers are selected by passing + # their settings dataclass via ``remote=``. ``solver_name`` and + # ``**solver_options`` describe the *inner* solver to run on the + # worker. + if isinstance(remote, _REMOTE_SETTINGS_TYPES): + return self._solve_with_remote_settings( + remote, + inner_solver=solver_name, + solver_options=solver_options, + reformulate_sos=reformulate_sos, + ) - if solved.objective.value is not None: - self.objective.set_value(float(solved.objective.value)) - self.status = solved.status - self.termination_condition = solved.termination_condition - for k, v in self.variables.items(): - v.solution = solved.variables[k].solution - for k, c in self.constraints.items(): - if "dual" in solved.constraints[k]: - c.dual = solved.constraints[k].dual - return self.status, self.termination_condition + if remote is not None: + # Back-compat shim: the legacy ``remote=OetcHandler/RemoteHandler`` + # shape pre-dates the standalone Oetc/SSH classes. Route to the + # new entrypoint and warn. Slated for removal once one release of + # overlap has shipped. + return self._solve_via_legacy_remote( + remote, + solver_name=solver_name, + io_api=io_api, + problem_fn=problem_fn, + solution_fn=solution_fn, + log_fn=log_fn, + basis_fn=basis_fn, + warmstart_fn=warmstart_fn, + keep_files=keep_files, + sanitize_zeros=sanitize_zeros, + reformulate_sos=reformulate_sos, + solver_options=solver_options, + ) if len(available_solvers) == 0: raise RuntimeError("No solver installed.") @@ -1827,6 +1846,158 @@ def solve( return self.assign_result(result) + def _solve_with_remote_settings( + self, + settings: Any, + *, + inner_solver: str | None, + solver_options: dict[str, Any], + reformulate_sos: bool | Literal["auto"], + ) -> tuple[str, str]: + """ + Dispatch a remote solve from an ``OetcSettings`` / ``SshSettings`` instance. + + The standalone remote classes (``Oetc``, ``SSH`` in + :mod:`linopy.remote`) are *not* :class:`linopy.solvers.Solver` + subclasses — they're a parallel concept. The instance is attached + to :attr:`Model.remote` after the call so callers can reuse the + authenticated connection. + """ + effective_inner: str | None = inner_solver + effective_options: dict[str, Any] = solver_options + if OetcSettings is not None and isinstance(settings, OetcSettings): + from linopy.remote.oetc import Oetc + + remote_cls: Any = Oetc + # Deprecated fallback to `OetcSettings.solver` / `solver_options` + # when `Model.solve` is called without a `solver_name`. + effective_inner = inner_solver or settings.solver + effective_options = {**(settings.solver_options or {}), **solver_options} + elif isinstance(settings, SshSettings): + from linopy.remote.ssh import SSH + + remote_cls = SSH + else: + raise TypeError( # pragma: no cover — checked by _REMOTE_SETTINGS_TYPES + f"Unknown remote settings type: {type(settings).__name__}" + ) + + if not effective_inner: + raise ValueError( + f"`m.solve(remote=<{type(settings).__name__}>)` requires " + "an explicit `solver_name=` for the solver to run " + "on the worker." + ) + + if self.objective.expression.empty: + raise ValueError( + "No objective has been set on the model. Use " + "`m.add_objective(...)` first (e.g. `m.add_objective(0 * x)` " + "for a pure feasibility problem)." + ) + + # Apply SOS reformulation before the remote handler serializes the + # model; the worker just solves a plain MILP, the lifecycle stays + # on this Model. ``sos_reformulation_context`` handles the + # apply/undo bracket, ``suppress_serialization_warning`` silences + # the ``to_netcdf`` UserWarning that fires when serializing in + # reformulated form (intentional here). + with sos_reformulation_context( + self, effective_inner, reformulate_sos + ) as applied: + with suppress_serialization_warning(active=applied): + remote_instance = remote_cls(settings) + self.remote = remote_instance + self.solver = None # remote-solve clears any prior local solver + solved = remote_instance.solve( + self, effective_inner, **effective_options + ) + return self._assign_from_solved_model(solved) + + def _solve_via_legacy_remote( + self, + remote: Any, + *, + solver_name: str | None, + io_api: str | None, + problem_fn: str | Path | None, + solution_fn: str | Path | None, + log_fn: str | Path | None, + basis_fn: str | Path | None, + warmstart_fn: str | Path | None, + keep_files: bool, + sanitize_zeros: bool, + reformulate_sos: bool | Literal["auto"], + solver_options: dict[str, Any], + ) -> tuple[str, str]: + """ + Back-compat path for ``Model.solve(remote=)``. + + Calls ``handler.solve_on_oetc(...)`` / ``handler.solve_on_remote(...)`` + as before — preserves the behavior tests on master are asserting + against — and emits a :class:`DeprecationWarning` pointing users at + the new ``remote=`` shape. + """ + if OetcHandler is not None and isinstance(remote, OetcHandler): + warnings.warn( + "Passing an OetcHandler via `remote=` is deprecated; pass " + "the OetcSettings directly: " + "`m.solve(remote=OetcSettings(...))`. The " + "`remote=OetcHandler/RemoteHandler` shape will be removed " + "in a future release.", + DeprecationWarning, + stacklevel=3, + ) + elif isinstance(remote, RemoteHandler): + warnings.warn( + "Passing a RemoteHandler via `remote=` is deprecated; pass " + "an SshSettings via `remote=` with a `solver_name=` for " + "the solver (`m.solve(solver_name, remote=SshSettings" + "(...))`). The `remote=OetcHandler/RemoteHandler` shape " + "will be removed in a future release.", + DeprecationWarning, + stacklevel=3, + ) + else: + raise TypeError( + f"`remote` must be an OetcHandler, RemoteHandler, " + f"OetcSettings, or SshSettings, got {type(remote).__name__}" + ) + + # The remote handlers short-circuit before reaching Solver.solve(), + # which is where the empty-objective check normally fires. Replicate + # it here. + if self.objective.expression.empty: + raise ValueError( + "No objective has been set on the model. Use " + "`m.add_objective(...)` first (e.g. `m.add_objective(0 * x)` " + "for a pure feasibility problem)." + ) + if OetcHandler is not None and isinstance(remote, OetcHandler): + solved = remote.solve_on_oetc( + self, + solver_name=solver_name, + reformulate_sos=reformulate_sos, + **solver_options, + ) + else: + solved = remote.solve_on_remote( + self, + solver_name=solver_name, + io_api=io_api, + problem_fn=problem_fn, + solution_fn=solution_fn, + log_fn=log_fn, + basis_fn=basis_fn, + warmstart_fn=warmstart_fn, + keep_files=keep_files, + sanitize_zeros=sanitize_zeros, + reformulate_sos=reformulate_sos, + **solver_options, + ) + + return self._assign_from_solved_model(solved) + def assign_result( self, result: Result, @@ -1889,6 +2060,27 @@ def assign_result( return status_value, termination_condition + def _assign_from_solved_model(self, solved: Model) -> tuple[str, str]: + """ + Fold a round-tripped solved model back onto this model in place. + + A remote worker produces a fully solved ``Model`` — variable + solutions, constraint duals, objective value, status. This copies + that data onto ``self`` (the model the caller invoked ``solve()`` + on), keyed by name. Used by the ``remote=`` paths of :meth:`solve`; + the local-solver path uses :meth:`assign_result` instead. + """ + if solved.objective.value is not None: + self.objective.set_value(float(solved.objective.value)) + self.status = solved.status + self.termination_condition = solved.termination_condition + for name, var in self.variables.items(): + var.solution = solved.variables[name].solution + for name, con in self.constraints.items(): + if "dual" in solved.constraints[name]: + con.dual = solved.constraints[name].dual + return self.status, self.termination_condition + def _mock_solve( self, sanitize_zeros: bool = True, diff --git a/linopy/remote/__init__.py b/linopy/remote/__init__.py index d3d5e162..c8642ec2 100644 --- a/linopy/remote/__init__.py +++ b/linopy/remote/__init__.py @@ -8,16 +8,19 @@ - OetcHandler: Cloud-based execution via OET Cloud service """ -from linopy.remote.ssh import RemoteHandler +from linopy.remote.ssh import SSH, RemoteHandler, SshSettings try: - from linopy.remote.oetc import OetcCredentials, OetcHandler, OetcSettings + from linopy.remote.oetc import Oetc, OetcCredentials, OetcHandler, OetcSettings except ImportError: pass __all__ = [ "RemoteHandler", + "SSH", + "SshSettings", "OetcHandler", + "Oetc", "OetcSettings", "OetcCredentials", ] diff --git a/linopy/remote/_common.py b/linopy/remote/_common.py new file mode 100644 index 00000000..5ade6b5e --- /dev/null +++ b/linopy/remote/_common.py @@ -0,0 +1,53 @@ +""" +Shared helper for the standalone remote classes (``Oetc``, ``SSH``). + +These classes do not inherit from :class:`linopy.solvers.Solver` — they're +a parallel concept. The helper here validates the solver string locally +before the round-trip to the worker, so an unknown name or an unsupported +feature fails fast instead of after the upload. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from linopy.model import Model + + +def _validate_inner_solver(inner_solver_name: str, model: Model) -> None: + """ + Check that the inner-solver string is locally known and + that the inner solver's feature set covers the model. + + Local installation is *not* required — feature flags are class-level + metadata. We only need the class to introspect ``supports(...)``. + Unknown solver names raise so typos fail fast instead of incurring a + round-trip to the worker. + """ + # Imported here to avoid a circular import at module load. + from linopy.solvers import SolverFeature, SolverName, _solver_class_for + + cls = _solver_class_for(inner_solver_name) + if cls is None: + valid = ", ".join(sorted(n.value for n in SolverName)) + raise ValueError( + f"Unknown solver name {inner_solver_name!r}. Pick one of: {valid}." + ) + if model.is_quadratic and not cls.supports(SolverFeature.QUADRATIC_OBJECTIVE): + raise ValueError( + f"Solver {inner_solver_name!r} does not support quadratic problems." + ) + if model.variables.semi_continuous and not cls.supports( + SolverFeature.SEMI_CONTINUOUS_VARIABLES + ): + raise ValueError( + f"Solver {inner_solver_name!r} does not support semi-continuous " + "variables. Use a solver that supports them (gurobi, cplex, highs)." + ) + if model.variables.sos and not cls.supports(SolverFeature.SOS_CONSTRAINTS): + raise ValueError( + f"Solver {inner_solver_name!r} does not support SOS constraints. " + "Reformulate first via `Model.solve(reformulate_sos=True)` or " + "`model.apply_sos_reformulation()`, or pick a solver that supports SOS." + ) diff --git a/linopy/remote/oetc.py b/linopy/remote/oetc.py index beef5873..73ef2b17 100644 --- a/linopy/remote/oetc.py +++ b/linopy/remote/oetc.py @@ -1,6 +1,7 @@ from __future__ import annotations import base64 +import contextlib import gzip import json import logging @@ -25,6 +26,8 @@ except ImportError: _oetc_deps_available = False +import warnings + import linopy from linopy.sos_reformulation import ( sos_reformulation_context, @@ -40,23 +43,74 @@ class ComputeProvider(str, Enum): @dataclass class OetcCredentials: + """ + .. deprecated:: + Pass ``email`` and ``password`` directly to :class:`OetcSettings` + instead of wrapping them in ``OetcCredentials``. This class will be + removed in a future release. + """ + email: str password: str + def __post_init__(self) -> None: + warnings.warn( + "`OetcCredentials` is deprecated; pass `email=` and `password=` " + "directly to `OetcSettings`. `OetcCredentials` will be removed " + "in a future release.", + DeprecationWarning, + stacklevel=2, + ) + @dataclass class OetcSettings: - credentials: OetcCredentials + """ + Connection config for the OET Cloud (OETC) remote service. + + Carries the auth/orchestrator endpoints and the worker resource + sizing. The solver is chosen per call — pass it to + :meth:`Model.solve` or :meth:`Oetc.submit`: + + >>> m.solve("gurobi", remote=OetcSettings(...), Method=2) # doctest: +SKIP + """ + name: str authentication_server_url: str orchestrator_server_url: str + email: str | None = None + password: str | None = None + credentials: OetcCredentials | None = None compute_provider: ComputeProvider = ComputeProvider.GCP - solver: str = "highs" - solver_options: dict[str, Any] = field(default_factory=dict) + solver: str | None = None + solver_options: dict[str, Any] | None = None cpu_cores: int = 2 disk_space_gb: int = 10 delete_worker_on_error: bool = False + def __post_init__(self) -> None: + if self.credentials is not None: + # `credentials=` warns from its own __post_init__; carry its + # values over unless `email` / `password` were also explicitly + # given (in which case the call site wins). + if self.email is None: + self.email = self.credentials.email + if self.password is None: + self.password = self.credentials.password + self.credentials = None + if not self.email or not self.password: + raise ValueError("`OetcSettings` requires `email` and `password`.") + if self.solver is not None or self.solver_options is not None: + warnings.warn( + "`OetcSettings.solver` and `OetcSettings.solver_options` are " + "deprecated and consulted only by the deprecated `OetcHandler`. " + "Pass the solver to `Model.solve(solver_name, remote=...)` or " + "`Oetc.submit(model, solver_name, ...)`. These fields will be " + "removed together with `OetcHandler`.", + DeprecationWarning, + stacklevel=2, + ) + @classmethod def from_env( cls, @@ -100,9 +154,8 @@ def from_env( ) kwargs: dict[str, Any] = { - "credentials": OetcCredentials( - email=resolved["email"], password=resolved["password"] - ), + "email": resolved["email"], + "password": resolved["password"], "name": resolved["name"], "authentication_server_url": resolved["authentication_server_url"], "orchestrator_server_url": resolved["orchestrator_server_url"], @@ -185,12 +238,30 @@ class JobResult: class OetcHandler: - def __init__(self, settings: OetcSettings) -> None: + """ + .. deprecated:: + Use :class:`~linopy.remote.Oetc` or :meth:`Model.solve(remote=OetcSettings(...)) + ` instead. This class will be removed in a + future release. The new :class:`Oetc` class owns the public lifecycle + (``submit`` / ``status`` / ``collect`` / ``solve``); ``OetcHandler`` + remains only for back-compat with code that holds a long-lived + handler instance. + """ + + def __init__(self, settings: OetcSettings, *, _internal: bool = False) -> None: if not _oetc_deps_available: raise ImportError( "The 'google-cloud-storage' and 'requests' packages are required " "for OetcHandler. Install them with: pip install linopy[oetc]" ) + if not _internal: + warnings.warn( + "`OetcHandler` is deprecated; use `Oetc(settings)` from " + "`linopy.remote` or `Model.solve(remote=OetcSettings(...))`. " + "`OetcHandler` will be removed in a future release.", + DeprecationWarning, + stacklevel=2, + ) self.settings = settings self.jwt = self.__sign_in() self.cloud_provider_credentials = self.__get_cloud_provider_credentials() @@ -208,8 +279,8 @@ def __sign_in(self) -> AuthenticationResult: try: logger.info("OETC - Signing in...") payload = { - "email": self.settings.credentials.email, - "password": self.settings.credentials.password, + "email": self.settings.email, + "password": self.settings.password, } response = requests.post( @@ -419,6 +490,38 @@ def _get_job_logs(self, job_uuid: str) -> str: logger.warning(f"OETC - Error fetching logs for job {job_uuid}: {e}") return f"[Error fetching logs: {e}]" + def _get_job(self, job_uuid: str) -> JobResult: + """ + Fetch the current job record in a single request (no polling). + + Raises ``RequestException`` on a failed request and ``KeyError`` + if the response is missing required fields. + """ + response = requests.get( + f"{self.settings.orchestrator_server_url}/compute-job/{job_uuid}", + headers={ + "Authorization": f"{self.jwt.token_type} {self.jwt.token}", + "Content-Type": "application/json", + }, + timeout=30, + ) + response.raise_for_status() + job_data_dict = response.json() + return JobResult( + uuid=job_data_dict["uuid"], + status=job_data_dict["status"], + name=job_data_dict.get("name"), + owner=job_data_dict.get("owner"), + solver=job_data_dict.get("solver"), + duration_in_seconds=job_data_dict.get("duration_in_seconds"), + solving_duration_in_seconds=job_data_dict.get( + "solving_duration_in_seconds" + ), + input_files=job_data_dict.get("input_files", []), + output_files=job_data_dict.get("output_files", []), + created_at=job_data_dict.get("created_at"), + ) + def wait_and_get_job_data( self, job_uuid: str, @@ -449,33 +552,11 @@ def wait_and_get_job_data( logger.info(f"OETC - Waiting for job {job_uuid} to complete...") while True: + if self.jwt.is_expired: + logger.info("OETC - Auth token expired; re-authenticating.") + self.jwt = self.__sign_in() try: - response = requests.get( - f"{self.settings.orchestrator_server_url}/compute-job/{job_uuid}", - headers={ - "Authorization": f"{self.jwt.token_type} {self.jwt.token}", - "Content-Type": "application/json", - }, - timeout=30, - ) - - response.raise_for_status() - job_data_dict = response.json() - - job_result = JobResult( - uuid=job_data_dict["uuid"], - status=job_data_dict["status"], - name=job_data_dict.get("name"), - owner=job_data_dict.get("owner"), - solver=job_data_dict.get("solver"), - duration_in_seconds=job_data_dict.get("duration_in_seconds"), - solving_duration_in_seconds=job_data_dict.get( - "solving_duration_in_seconds" - ), - input_files=job_data_dict.get("input_files", []), - output_files=job_data_dict.get("output_files", []), - created_at=job_data_dict.get("created_at"), - ) + job_result = self._get_job(job_uuid) consecutive_failures = 0 @@ -645,11 +726,17 @@ def solve_on_oetc( """ Solve a linopy model on the OET Cloud compute app. + .. deprecated:: + Use :class:`Oetc` or + :meth:`Model.solve(remote=OetcSettings(...)) `. + Parameters ---------- model : linopy.model.Model solver_name : str, optional Override the solver from settings. + reformulate_sos : bool | "auto", optional + See :meth:`linopy.model.Model.solve`. **solver_options Override/extend solver_options from settings. @@ -657,55 +744,34 @@ def solve_on_oetc( ------- linopy.model.Model Solved model. - - Raises - ------ - Exception: If solving fails at any stage """ - try: - effective_solver = solver_name or self.settings.solver - merged_solver_options = {**self.settings.solver_options, **solver_options} + # Delegates to ``Oetc`` so the upload→submit→poll→download + # orchestration lives in one place. + effective_solver = solver_name or self.settings.solver or "highs" + merged_solver_options = { + **(self.settings.solver_options or {}), + **solver_options, + } + oetc = Oetc(settings=self.settings) + oetc._handler = self # reuse this handler so auth is not refetched + try: with sos_reformulation_context( model, effective_solver, reformulate_sos ) as applied: - with tempfile.NamedTemporaryFile(prefix="linopy-", suffix=".nc") as fn: - fn.file.close() - with suppress_serialization_warning(active=applied): - model.to_netcdf(fn.name) - input_file_name = self._upload_file_to_gcp(fn.name) - - job_uuid = self._submit_job_to_compute_service( - input_file_name, effective_solver, merged_solver_options - ) - job_result = self.wait_and_get_job_data(job_uuid) - - if not job_result.output_files: - raise Exception("No output files found in completed job") - - output_file_name = job_result.output_files[0] - if isinstance(output_file_name, dict) and "name" in output_file_name: - output_file_name = output_file_name["name"] - - solution_file_path = self._download_file_from_gcp(output_file_name) - - solved_model = linopy.read_netcdf(solution_file_path) - - os.remove(solution_file_path) - - logger.info( - f"OETC - Model solved successfully. Status: {solved_model.status}" - ) - if solved_model.objective.value is not None: - logger.info( - f"OETC - Objective value: {solved_model.objective.value:.2e}" - ) - - return solved_model - + with suppress_serialization_warning(active=applied): + job_uuid = oetc.submit( + model, effective_solver, **merged_solver_options + ) + solved_model = oetc.collect(job_uuid) except Exception as e: raise Exception(f"Error solving model on OETC: {e}") from e + logger.info(f"OETC - Model solved successfully. Status: {solved_model.status}") + if solved_model.objective.value is not None: + logger.info(f"OETC - Objective value: {solved_model.objective.value:.2e}") + return solved_model + def _gzip_compress(self, source_path: str) -> str: """ Compress a file using gzip compression. @@ -786,3 +852,97 @@ def _upload_file_to_gcp(self, file_path: str) -> str: except Exception as e: raise Exception(f"Failed to upload file to GCP: {e}") + + +@dataclass +class Oetc: + """ + A session with the OET Cloud (OETC) managed compute service. + + This is a standalone class — *not* a :class:`linopy.solvers.Solver` + subclass. An ``Oetc`` instance is a *session*, not a job: it holds an + auth token — not a persistent connection — and can submit and collect + any number of jobs over HTTPS. + + A job is identified solely by the uuid string returned from + :meth:`submit`. Because the handle is just a string, the lifecycle is + async-friendly — submit many models, hold their uuids, and + :meth:`collect` each when convenient, even from a different process + (a fresh ``Oetc(settings)`` re-authenticates and collects by uuid). + + Parameters + ---------- + settings : OetcSettings + Auth + orchestrator config (where to talk to). + """ + + settings: OetcSettings + + _handler: OetcHandler | None = field(init=False, default=None, repr=False) + + @classmethod + def is_available(cls) -> bool: + """Return True iff the OETC network deps are importable.""" + return _oetc_deps_available + + def _session(self) -> OetcHandler: + """ + Return the authenticated handler. + + Builds it on first use, and rebuilds it (re-authenticating) once + the previous auth token has expired — so a long-lived ``Oetc`` + keeps working across the token lifetime. + """ + if self._handler is None or self._handler.jwt.is_expired: + self._handler = OetcHandler(self.settings, _internal=True) + return self._handler + + def submit(self, model: Model, solver_name: str, **options: Any) -> str: + """ + Serialize and upload the model, submit the job, and return its uuid. + + The uuid is the only handle a job needs — persist it and + :meth:`collect` later, from this or any other process. + """ + handler = self._session() + with tempfile.NamedTemporaryFile(prefix="linopy-", suffix=".nc") as fn: + fn.file.close() + model.to_netcdf(fn.name) + input_file_name = handler._upload_file_to_gcp(fn.name) + return handler._submit_job_to_compute_service( + input_file_name, solver_name, dict(options) + ) + + def status(self, job_uuid: str) -> str: + """Return the current job status in a single, non-blocking request.""" + return self._session()._get_job(job_uuid).status + + def collect(self, job_uuid: str) -> Model: + """ + Block until the job finishes, download, and return the solved model. + + Needs only the uuid and the session, so it can run in a + different process than the one that called :meth:`submit`. + """ + handler = self._session() + job_result = handler.wait_and_get_job_data(job_uuid) + if not job_result.output_files: + raise Exception("No output files found in completed job") + output_file_name = job_result.output_files[0] + if isinstance(output_file_name, dict) and "name" in output_file_name: + output_file_name = output_file_name["name"] + + solution_file_path = handler._download_file_from_gcp(output_file_name) + try: + return linopy.read_netcdf(solution_file_path) + finally: + with contextlib.suppress(OSError): + os.remove(solution_file_path) + + def solve(self, model: Model, solver_name: str, **options: Any) -> Model: + """Submit the model and block until the solved model is back.""" + from linopy.remote._common import _validate_inner_solver + + _validate_inner_solver(solver_name, model) + job_uuid = self.submit(model, solver_name, **options) + return self.collect(job_uuid) diff --git a/linopy/remote/ssh.py b/linopy/remote/ssh.py index 7c0a0644..03c06d07 100644 --- a/linopy/remote/ssh.py +++ b/linopy/remote/ssh.py @@ -8,8 +8,9 @@ import logging import os import tempfile +import warnings from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Literal, Union from linopy.io import read_netcdf @@ -37,11 +38,42 @@ """ +@dataclass +class SshSettings: + """ + Connection config for the :class:`~linopy.remote.SSH` remote. + + Solver name and solver options come from :meth:`Model.solve` — + ``m.solve("gurobi", remote=SshSettings(hostname=...), presolve="on")``. + + Use ``setup_commands`` to prepare the remote shell before the solve — + e.g. activate a conda environment or set ``PATH``:: + + SshSettings(hostname=..., setup_commands=["conda activate linopy-env"]) + """ + + hostname: str + port: int = 22 + username: str | None = None + password: str | None = None + python_executable: str = "python" + python_file: str = "/tmp/linopy-execution.py" + model_unsolved_file: str = "/tmp/linopy-unsolved-model.nc" + model_solved_file: str = "/tmp/linopy-solved-model.nc" + setup_commands: list[str] = field(default_factory=list) + + @dataclass class RemoteHandler: """ Handler class for solving models on a remote machine via an SSH connection. + .. deprecated:: + ``RemoteHandler`` is the legacy low-level entry point and will be + removed in a future release. Prefer + ``Model.solve("gurobi", remote=SshSettings(hostname=...))`` or + instantiate :class:`SSH` directly. + The basic idea of the handler is to provide a workflow that: 1. defines a model on the local machine @@ -133,9 +165,20 @@ class RemoteHandler: model_unsolved_file: str = "/tmp/linopy-unsolved-model.nc" model_solved_file: str = "/tmp/linopy-solved-model.nc" + _internal: bool = field(default=False, repr=False) + def __post_init__(self) -> None: assert paramiko_present, "The required paramiko package is not installed." + if not self._internal: + warnings.warn( + "`RemoteHandler` is deprecated; use `SSH(settings)` from " + "`linopy.remote` or `Model.solve(remote=SshSettings(hostname=...))`. " + "`RemoteHandler` will be removed in a future release.", + DeprecationWarning, + stacklevel=2, + ) + if self.client is None: client = paramiko.SSHClient() client.load_system_host_keys() @@ -256,3 +299,56 @@ def solve_on_remote( self.sftp_client.remove(self.model_solved_file) return solved + + +@dataclass +class SSH: + """ + A connection to a remote machine you run yourself, reached over SSH. + + This is a standalone class — *not* a :class:`linopy.solvers.Solver` + subclass. It ships the model to a remote host and runs + ``read_netcdf(...).solve(solver_name=...)`` there, pulling the solved + model back. Where :class:`Oetc` targets a managed cloud service, + ``SSH`` targets a machine you provide. The SSH shell job is + short-lived and synchronous, so there is no submit/collect seam — + just :meth:`solve`. + + Parameters + ---------- + settings : SshSettings + Connection + remote-execution paths. + """ + + settings: SshSettings + + _handler: "RemoteHandler | None" = field(init=False, default=None, repr=False) + + @classmethod + def is_available(cls) -> bool: + """Return True iff paramiko is importable.""" + return paramiko_present + + def solve(self, model: "Model", solver_name: str, **options: Any) -> "Model": + """Ship the model, run the solver on the remote, return the solved model.""" + from linopy.remote._common import _validate_inner_solver + + _validate_inner_solver(solver_name, model) + + if self._handler is None: + self._handler = RemoteHandler( + hostname=self.settings.hostname, + port=self.settings.port, + username=self.settings.username, + password=self.settings.password, + python_executable=self.settings.python_executable, + python_file=self.settings.python_file, + model_unsolved_file=self.settings.model_unsolved_file, + model_solved_file=self.settings.model_solved_file, + _internal=True, + ) + for cmd in self.settings.setup_commands: + self._handler.execute(cmd) + + solve_kwargs: dict[str, Any] = {"solver_name": solver_name, **options} + return self._handler.solve_on_remote(model, **solve_kwargs) diff --git a/linopy/solvers.py b/linopy/solvers.py index a28da898..6a6f23f0 100644 --- a/linopy/solvers.py +++ b/linopy/solvers.py @@ -515,11 +515,22 @@ def from_model( model: Model, io_api: str | None = None, options: dict[str, Any] | None = None, - **build_kwargs: Any, + **kwargs: Any, ) -> Solver: - """Instantiate and build the solver against ``model``.""" - instance = cls(model=model, io_api=io_api, options=options or {}) - instance._build(**build_kwargs) + """ + Instantiate and build the solver against ``model``. + + Any ``kwargs`` whose name matches an ``init=True`` dataclass field on + the subclass (e.g. ``settings`` on :class:`Oetc` / :class:`SSH`) are + forwarded to the constructor; the rest go to ``_build`` as + ``build_kwargs``. + """ + from dataclasses import fields + + field_names = {f.name for f in fields(cls) if f.init} + ctor_kw = {k: kwargs.pop(k) for k in list(kwargs) if k in field_names} + instance = cls(model=model, io_api=io_api, options=options or {}, **ctor_kw) + instance._build(**kwargs) return instance def _build(self, **build_kwargs: Any) -> None: diff --git a/pyproject.toml b/pyproject.toml index 67297677..ac916eb0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ oetc = [ "google-cloud-storage", "requests", ] -remote = [ +ssh = [ "paramiko", ] docs = [ diff --git a/test/remote/test_oetc.py b/test/remote/test_oetc.py index 7b2d75f2..6aa06ff0 100644 --- a/test/remote/test_oetc.py +++ b/test/remote/test_oetc.py @@ -776,7 +776,7 @@ def handler_with_mocked_auth(self) -> OetcHandler: handler = OetcHandler.__new__(OetcHandler) handler.settings = settings - handler.jwt = Mock() + handler.jwt = Mock(is_expired=False) handler.cloud_provider_credentials = Mock() return handler @@ -865,7 +865,7 @@ def handler_with_gcp_credentials( handler = OetcHandler.__new__(OetcHandler) handler.settings = settings - handler.jwt = Mock() + handler.jwt = Mock(is_expired=False) handler.cloud_provider_credentials = gcp_creds return handler @@ -1009,7 +1009,7 @@ def handler_with_mocked_auth(self) -> OetcHandler: handler = OetcHandler.__new__(OetcHandler) handler.settings = settings - handler.jwt = Mock() + handler.jwt = Mock(is_expired=False) handler.cloud_provider_credentials = Mock() return handler @@ -1151,7 +1151,7 @@ def handler_with_gcp_credentials( handler = OetcHandler.__new__(OetcHandler) handler.settings = settings - handler.jwt = Mock() + handler.jwt = Mock(is_expired=False) handler.cloud_provider_credentials = gcp_creds return handler @@ -1512,7 +1512,7 @@ def handler_with_complete_setup( handler = OetcHandler.__new__(OetcHandler) handler.settings = settings - handler.jwt = Mock() + handler.jwt = Mock(is_expired=False) handler.cloud_provider_credentials = gcp_creds return handler @@ -1637,7 +1637,7 @@ def handler_with_full_setup(self) -> OetcHandler: handler = OetcHandler.__new__(OetcHandler) handler.settings = settings - handler.jwt = Mock() + handler.jwt = Mock(is_expired=False) handler.cloud_provider_credentials = gcp_creds return handler diff --git a/test/remote/test_oetc_job_polling.py b/test/remote/test_oetc_job_polling.py index 4b2681f9..6a0e7840 100644 --- a/test/remote/test_oetc_job_polling.py +++ b/test/remote/test_oetc_job_polling.py @@ -16,6 +16,7 @@ from linopy.remote.oetc import ( # noqa: E402 AuthenticationResult, ComputeProvider, + JobResult, OetcCredentials, OetcHandler, OetcSettings, @@ -166,6 +167,32 @@ def test_polling_interval_backoff( assert sleep_calls[0] == 10 # Initial interval assert sleep_calls[1] == 15 # 10 * 1.5 = 15 + def test_reauth_when_token_expired_during_poll( + self, mock_settings: OetcSettings + ) -> None: + """The poll loop re-signs-in when the auth token expires mid-poll.""" + fresh = AuthenticationResult("new", "Bearer", 3600, datetime.now()) + sign_in = Mock(return_value=fresh) + finished = JobResult( + uuid="job-1", status="FINISHED", output_files=["out.nc.gz"] + ) + with ( + patch("linopy.remote.oetc.OetcHandler._OetcHandler__sign_in", sign_in), + patch( + "linopy.remote.oetc.OetcHandler." + "_OetcHandler__get_cloud_provider_credentials" + ), + ): + handler = OetcHandler(mock_settings) + handler.jwt = AuthenticationResult("old", "Bearer", -1, datetime.now()) + sign_in.reset_mock() + with patch.object(handler, "_get_job", return_value=finished): + result = handler.wait_and_get_job_data("job-1") + + assert result.status == "FINISHED" + sign_in.assert_called_once() + assert handler.jwt is fresh + class TestJobPollingErrors: """Test job polling error scenarios.""" diff --git a/test/remote/test_remotes.py b/test/remote/test_remotes.py new file mode 100644 index 00000000..89473d6c --- /dev/null +++ b/test/remote/test_remotes.py @@ -0,0 +1,465 @@ +""" +Tests for the standalone remote classes (``Oetc`` / ``SSH``) and the +``Model.solve(remote=)`` entry point. + +The deprecated ``OetcHandler`` / ``RemoteHandler`` are covered by +``test_oetc.py`` and ``test_ssh.py`` separately; this file focuses on +the *new* public surface and its deprecation warnings. +""" + +from __future__ import annotations + +import warnings +from typing import Any +from unittest.mock import MagicMock, patch + +import pandas as pd +import pytest + +from linopy import Model +from linopy.remote import ( + Oetc, + OetcCredentials, + OetcHandler, + OetcSettings, + RemoteHandler, + SshSettings, +) + +pytest.importorskip("paramiko") +from linopy.remote.ssh import SSH # noqa: E402 + +# --------------------------------------------------------------------------- +# Helpers + + +def _build_model() -> Model: + m = Model() + idx = pd.Index([0, 1, 2], name="i") + x = m.add_variables(lower=0, upper=1, coords=[idx], name="x") + m.add_constraints(x >= 0, name="c") + m.add_objective(1.0 * x.sum()) + return m + + +def _settings_oetc() -> OetcSettings: + return OetcSettings( + email="a@b.com", + password="pw", + name="test-job", + authentication_server_url="https://auth", + orchestrator_server_url="https://orch", + ) + + +def _settings_ssh() -> SshSettings: + return SshSettings(hostname="example.org", username="me") + + +def _fake_oetc_handler() -> MagicMock: + """A MagicMock(spec=OetcHandler) with the methods Oetc.submit/status/collect call.""" + h = MagicMock(spec=OetcHandler) + h.jwt = MagicMock(is_expired=False) # a freshly authenticated handler + h._upload_file_to_gcp = MagicMock(return_value="model.nc.gz") + h._submit_job_to_compute_service = MagicMock(return_value="job-uuid") + job_result = MagicMock() + job_result.output_files = [{"name": "result.nc.gz"}] + job_result.duration_in_seconds = 42 + h.wait_and_get_job_data = MagicMock(return_value=job_result) + h._get_job = MagicMock(return_value=MagicMock(status="RUNNING")) + h._download_file_from_gcp = MagicMock(return_value="/tmp/fake-result.nc") + return h + + +def _solved_model_like(m: Model) -> Model: + """Build a Model with the same labels as ``m`` plus dummy solution data.""" + solved = Model() + for name, var in m.variables.items(): + solved_var = solved.add_variables( + lower=var.lower, upper=var.upper, coords=var.coords, name=name + ) + solved_var.solution = solved_var.lower * 0 # zeros, real DataArray + for name, con in m.constraints.items(): + solved.add_constraints(con.lhs >= con.rhs, name=name) + solved.add_objective(m.objective.expression) + solved.objective._value = 0.0 + solved.termination_condition = "optimal" + solved.status = "ok" + return solved + + +# --------------------------------------------------------------------------- +# Oetc class + + +class TestOetcClass: + def test_solve_runs_submit_and_collect( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + m = _build_model() + oetc = Oetc(_settings_oetc()) + oetc._handler = _fake_oetc_handler() # bypass auth + + monkeypatch.setattr( + "linopy.remote.oetc.linopy.read_netcdf", + lambda path: _solved_model_like(m), + ) + + result = oetc.solve(m, "highs") + + assert isinstance(result, Model) + oetc._handler._upload_file_to_gcp.assert_called_once() + oetc._handler._submit_job_to_compute_service.assert_called_once() + oetc._handler.wait_and_get_job_data.assert_called_once_with("job-uuid") + oetc._handler._download_file_from_gcp.assert_called_once_with("result.nc.gz") + + def test_validates_unknown_solver_name(self) -> None: + m = _build_model() + oetc = Oetc(_settings_oetc()) + oetc._handler = _fake_oetc_handler() + with pytest.raises(ValueError, match="Unknown solver"): + oetc.solve(m, "not-a-solver") + + def test_submit_collect_separable_by_uuid( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """The submit/collect seam can be driven manually for async work.""" + m = _build_model() + oetc = Oetc(_settings_oetc()) + oetc._handler = _fake_oetc_handler() + monkeypatch.setattr( + "linopy.remote.oetc.linopy.read_netcdf", + lambda path: _solved_model_like(m), + ) + + job_uuid = oetc.submit(m, "highs") + assert job_uuid == "job-uuid" + assert oetc._handler._upload_file_to_gcp.call_count == 1 + assert oetc._handler._submit_job_to_compute_service.call_count == 1 + + result = oetc.collect(job_uuid) + assert isinstance(result, Model) + oetc._handler.wait_and_get_job_data.assert_called_once_with("job-uuid") + + def test_status_returns_job_state(self) -> None: + oetc = Oetc(_settings_oetc()) + oetc._handler = _fake_oetc_handler() + assert oetc.status("job-uuid") == "RUNNING" + oetc._handler._get_job.assert_called_once_with("job-uuid") + + def test_one_connection_drives_multiple_jobs( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """A single Oetc connection submits and collects many models.""" + models = [_build_model() for _ in range(3)] + oetc = Oetc(_settings_oetc()) + oetc._handler = _fake_oetc_handler() + monkeypatch.setattr( + "linopy.remote.oetc.linopy.read_netcdf", + lambda path: _solved_model_like(models[0]), + ) + + uuids = [oetc.submit(m, "highs") for m in models] + assert len(uuids) == 3 + solved = [oetc.collect(u) for u in uuids] + assert all(isinstance(s, Model) for s in solved) + assert oetc._handler._submit_job_to_compute_service.call_count == 3 + assert oetc._handler.wait_and_get_job_data.call_count == 3 + + def test_collect_by_uuid_from_a_fresh_connection( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """A job uuid can be collected by an Oetc that never submitted it.""" + m = _build_model() + submitter = Oetc(_settings_oetc()) + submitter._handler = _fake_oetc_handler() + job_uuid = submitter.submit(m, "highs") + + # Simulate a separate process: a brand-new Oetc, given only the uuid. + collector = Oetc(_settings_oetc()) + collector._handler = _fake_oetc_handler() + monkeypatch.setattr( + "linopy.remote.oetc.linopy.read_netcdf", + lambda path: _solved_model_like(m), + ) + result = collector.collect(job_uuid) + assert isinstance(result, Model) + + def test_expired_token_triggers_reauth( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """A stale auth token makes the next call rebuild the handler.""" + oetc = Oetc(_settings_oetc()) + stale = _fake_oetc_handler() + stale.jwt = MagicMock(is_expired=True) + oetc._handler = stale + + rebuilt = _fake_oetc_handler() + monkeypatch.setattr( + "linopy.remote.oetc.OetcHandler", + lambda settings, _internal=False: rebuilt, + ) + + assert oetc.status("job-uuid") == "RUNNING" + assert oetc._handler is rebuilt # expired token -> reconnected + + +# --------------------------------------------------------------------------- +# SSH class + + +class TestSSHClass: + def test_solve_runs_setup_commands_then_delegates(self) -> None: + m = _build_model() + ssh = SSH( + SshSettings( + hostname="example.org", + setup_commands=["conda activate linopy-env", "export FOO=bar"], + ) + ) + fake_handler = MagicMock(spec=RemoteHandler) + fake_handler.execute = MagicMock() + fake_handler.solve_on_remote = MagicMock(return_value=_solved_model_like(m)) + ssh._handler = fake_handler + + result = ssh.solve(m, "highs") + + assert isinstance(result, Model) + # solve_on_remote is the public surface from the deprecated handler + fake_handler.solve_on_remote.assert_called_once() + # setup_commands run only on first handler construction; here _handler + # was injected, so they shouldn't run automatically: + fake_handler.execute.assert_not_called() + + def test_setup_commands_run_when_handler_is_built_internally( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """First .solve() with a fresh SSH builds a RemoteHandler and runs setup.""" + m = _build_model() + ssh = SSH( + SshSettings( + hostname="example.org", + setup_commands=["conda activate linopy-env"], + ) + ) + + built: list[Any] = [] + + class FakeRemoteHandler: + def __init__(self, **kwargs: Any) -> None: + self.kwargs = kwargs + self.execute = MagicMock() + self.solve_on_remote = MagicMock(return_value=_solved_model_like(m)) + built.append(self) + + monkeypatch.setattr("linopy.remote.ssh.RemoteHandler", FakeRemoteHandler) + ssh.solve(m, "highs") + + assert len(built) == 1 + built[0].execute.assert_called_once_with("conda activate linopy-env") + assert built[0].kwargs.get("_internal") is True + + def test_validates_unknown_solver_name(self) -> None: + m = _build_model() + ssh = SSH(_settings_ssh()) + ssh._handler = MagicMock(spec=RemoteHandler) + with pytest.raises(ValueError, match="Unknown solver"): + ssh.solve(m, "not-a-solver") + + +# --------------------------------------------------------------------------- +# Model.solve(remote=) end-to-end + + +class TestModelSolveRemote: + def test_oetc_settings_dispatches_to_oetc( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + m = _build_model() + captured: dict[str, Any] = {} + + def fake_solve( + self: Oetc, model: Model, solver_name: str, **options: Any + ) -> Model: + captured["solver_name"] = solver_name + captured["options"] = options + captured["instance"] = self + return _solved_model_like(model) + + monkeypatch.setattr(Oetc, "solve", fake_solve) + + m.solve("gurobi", remote=_settings_oetc(), Method=2) + + assert captured["solver_name"] == "gurobi" + assert captured["options"] == {"Method": 2} + assert m.remote is captured["instance"] + assert m.solver is None # remote-solve clears any prior local solver + + def test_oetc_settings_solver_used_when_no_solver_name( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """ + The deprecated `OetcSettings.solver` is the fallback when + `Model.solve(remote=...)` is called without a `solver_name`. + """ + m = _build_model() + captured: dict[str, Any] = {} + + def fake_solve( + self: Oetc, model: Model, solver_name: str, **options: Any + ) -> Model: + captured["solver_name"] = solver_name + captured["options"] = options + return _solved_model_like(model) + + monkeypatch.setattr(Oetc, "solve", fake_solve) + + with pytest.warns(DeprecationWarning, match=r"OetcSettings\.solver"): + settings = OetcSettings( + email="a@b.com", + password="pw", + name="test-job", + authentication_server_url="https://auth", + orchestrator_server_url="https://orch", + solver="cplex", + solver_options={"TimeLimit": 10}, + ) + m.solve(remote=settings) + + assert captured["solver_name"] == "cplex" + assert captured["options"] == {"TimeLimit": 10} + + def test_ssh_settings_dispatches_to_ssh( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + m = _build_model() + captured: dict[str, Any] = {} + + def fake_solve( + self: SSH, model: Model, solver_name: str, **options: Any + ) -> Model: + captured["solver_name"] = solver_name + captured["options"] = options + captured["instance"] = self + return _solved_model_like(model) + + monkeypatch.setattr(SSH, "solve", fake_solve) + + m.solve("highs", remote=_settings_ssh(), presolve="on") + + assert captured["solver_name"] == "highs" + assert captured["options"] == {"presolve": "on"} + assert m.remote is captured["instance"] + + @pytest.mark.parametrize( + ("remote_cls", "settings_factory", "solver"), + [(Oetc, _settings_oetc, "gurobi"), (SSH, _settings_ssh, "highs")], + ) + def test_remote_solve_writes_solution_onto_caller_model( + self, + monkeypatch: pytest.MonkeyPatch, + remote_cls: type, + settings_factory: Any, + solver: str, + ) -> None: + """ + `Model.solve(remote=...)` folds the solved model into the caller's + own model in place and returns the (status, termination_condition) + tuple — it never hands back the round-tripped model object. + """ + m = _build_model() + + def fake_solve( + self: Any, model: Model, solver_name: str, **options: Any + ) -> Model: + return _solved_model_like(model) + + monkeypatch.setattr(remote_cls, "solve", fake_solve) + + result = m.solve(solver, remote=settings_factory()) + + assert result == ("ok", "optimal") + assert m.status == "ok" + assert m.termination_condition == "optimal" + assert m.objective.value == 0.0 + assert float(m.variables["x"].solution.sum()) == 0.0 + + +# --------------------------------------------------------------------------- +# Deprecation warnings + + +class TestDeprecations: + def test_oetc_credentials_construction_warns(self) -> None: + with pytest.warns(DeprecationWarning, match="OetcCredentials"): + OetcCredentials(email="a@b.com", password="pw") + + def test_oetc_settings_credentials_kwarg_carries_values_through(self) -> None: + # Constructing OetcCredentials warns (its own __post_init__). + with pytest.warns(DeprecationWarning, match="OetcCredentials"): + creds = OetcCredentials(email="a@b.com", password="pw") + + s = OetcSettings( + credentials=creds, + name="n", + authentication_server_url="https://a", + orchestrator_server_url="https://o", + ) + assert s.email == "a@b.com" + assert s.password == "pw" + # `credentials` is consumed and cleared. + assert s.credentials is None + + def test_oetc_settings_requires_email_and_password(self) -> None: + with pytest.raises(ValueError, match="email.*password"): + OetcSettings( + name="n", + authentication_server_url="https://a", + orchestrator_server_url="https://o", + ) + + def test_oetc_handler_construction_warns(self) -> None: + with ( + patch.object(OetcHandler, "_OetcHandler__sign_in"), + patch.object(OetcHandler, "_OetcHandler__get_cloud_provider_credentials"), + ): + with pytest.warns(DeprecationWarning, match="OetcHandler"): + OetcHandler(_settings_oetc()) + + def test_oetc_handler_internal_construction_silent(self) -> None: + with ( + patch.object(OetcHandler, "_OetcHandler__sign_in"), + patch.object(OetcHandler, "_OetcHandler__get_cloud_provider_credentials"), + ): + with warnings.catch_warnings(): + warnings.simplefilter("error") + OetcHandler(_settings_oetc(), _internal=True) + + def test_remote_handler_construction_warns( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + fake_client = MagicMock() + fake_client.invoke_shell.return_value.makefile.return_value = MagicMock() + fake_client.open_sftp.return_value = MagicMock() + + with pytest.warns(DeprecationWarning, match="RemoteHandler"): + RemoteHandler(hostname="x", client=fake_client) + + def test_remote_handler_internal_construction_silent(self) -> None: + fake_client = MagicMock() + fake_client.invoke_shell.return_value.makefile.return_value = MagicMock() + fake_client.open_sftp.return_value = MagicMock() + + with warnings.catch_warnings(): + warnings.simplefilter("error") + RemoteHandler(hostname="x", client=fake_client, _internal=True) + + def test_model_solve_remote_handler_warns( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + m = _build_model() + handler = MagicMock(spec=OetcHandler) + handler.settings = _settings_oetc() + handler.solve_on_oetc = MagicMock(return_value=_solved_model_like(m)) + with pytest.warns(DeprecationWarning, match="OetcHandler.*remote="): + m.solve(solver_name="highs", remote=handler) diff --git a/test/test_oetc_settings.py b/test/test_oetc_settings.py index 12deeb66..99323094 100644 --- a/test/test_oetc_settings.py +++ b/test/test_oetc_settings.py @@ -7,7 +7,6 @@ from linopy.remote.oetc import ( ComputeProvider, - OetcCredentials, OetcHandler, OetcSettings, ) @@ -48,8 +47,8 @@ def test_from_env_all_set(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("OETC_DELETE_WORKER_ON_ERROR", "true") s = OetcSettings.from_env() - assert s.credentials.email == "test@example.com" - assert s.credentials.password == "secret" + assert s.email == "test@example.com" + assert s.password == "secret" assert s.name == "test-job" assert s.cpu_cores == 8 assert s.disk_space_gb == 20 @@ -62,7 +61,7 @@ def test_from_env_kwargs_override(monkeypatch: pytest.MonkeyPatch) -> None: _set_required_env(monkeypatch) s = OetcSettings.from_env(email="override@example.com") - assert s.credentials.email == "override@example.com" + assert s.email == "override@example.com" def test_from_env_missing_required(monkeypatch: pytest.MonkeyPatch) -> None: @@ -93,7 +92,7 @@ def test_from_env_partial_kwargs(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("OETC_ORCHESTRATOR_URL", "https://orch.example.com") s = OetcSettings.from_env(email="a@b.com", password="pw") - assert s.credentials.email == "a@b.com" + assert s.email == "a@b.com" assert s.name == "env-name" @@ -102,14 +101,28 @@ def test_from_env_defaults_applied(monkeypatch: pytest.MonkeyPatch) -> None: _set_required_env(monkeypatch) s = OetcSettings.from_env() - assert s.solver == "highs" - assert s.solver_options == {} + assert s.solver is None + assert s.solver_options is None assert s.cpu_cores == 2 assert s.disk_space_gb == 10 assert s.compute_provider == ComputeProvider.GCP assert s.delete_worker_on_error is False +def test_solver_fields_emit_deprecation_warning() -> None: + base: dict[str, Any] = dict( + email="a@b.com", + password="pw", + name="test", + authentication_server_url="https://auth", + orchestrator_server_url="https://orch", + ) + with pytest.warns(DeprecationWarning, match=r"OetcSettings\.solver"): + OetcSettings(**base, solver="gurobi") + with pytest.warns(DeprecationWarning, match=r"OetcSettings\.solver"): + OetcSettings(**base, solver_options={"TimeLimit": 100}) + + def test_from_env_cpu_cores_valid(monkeypatch: pytest.MonkeyPatch) -> None: _clear_oetc_env(monkeypatch) _set_required_env(monkeypatch) @@ -157,7 +170,11 @@ def test_from_env_bool_invalid(monkeypatch: pytest.MonkeyPatch) -> None: def _make_handler(settings: OetcSettings) -> OetcHandler: with ( patch("linopy.remote.oetc._oetc_deps_available", True), - patch.object(OetcHandler, "_OetcHandler__sign_in", return_value=MagicMock()), + patch.object( + OetcHandler, + "_OetcHandler__sign_in", + return_value=MagicMock(is_expired=False), + ), patch.object( OetcHandler, "_OetcHandler__get_cloud_provider_credentials", @@ -169,7 +186,8 @@ def _make_handler(settings: OetcSettings) -> OetcHandler: def _default_settings(**overrides: Any) -> OetcSettings: defaults: dict[str, Any] = dict( - credentials=OetcCredentials(email="a@b.com", password="pw"), + email="a@b.com", + password="pw", name="test", authentication_server_url="https://auth", orchestrator_server_url="https://orch", @@ -183,7 +201,7 @@ def _default_settings(**overrides: Any) -> OetcSettings: def test_solve_on_oetc_mutation_safety() -> None: settings = _default_settings() handler = _make_handler(settings) - original_opts = dict(settings.solver_options) + original_opts = dict(settings.solver_options or {}) mock_model = MagicMock() mock_solved = MagicMock() diff --git a/test/test_sos_reformulation.py b/test/test_sos_reformulation.py index 0e9dc9da..765de6a1 100644 --- a/test/test_sos_reformulation.py +++ b/test/test_sos_reformulation.py @@ -6,16 +6,14 @@ import warnings from collections.abc import Callable from pathlib import Path -from typing import Literal, cast +from typing import Any import numpy as np import pandas as pd import pytest -import xarray as xr from linopy import Model, Variable, available_solvers from linopy.constants import SOS_TYPE_ATTR -from linopy.remote import RemoteHandler from linopy.sos_reformulation import ( compute_big_m_values, reformulate_sos1, @@ -1188,64 +1186,54 @@ def _sos_model() -> Model: m.add_objective(x * np.array([1.0, 2.0, 3.0]), sense="max") return m - def _fake_handler( - self, observed: dict[str, object], tmp_path: Path - ) -> RemoteHandler: + @staticmethod + def _patch_ssh_solve( + monkeypatch: pytest.MonkeyPatch, + observed: dict[str, object], + tmp_path: Path, + ) -> None: """ - Non-OetcHandler stand-in with the SSH-shaped `solve_on_remote`. - - Records whether the model arrives in reformulated form, then runs - `model.to_netcdf(...)` and `read_netcdf(...)` (naturally — no - warning recording here, so we can observe at the call-site whether - Model.solve's suppression worked). + Replace ``linopy.remote.ssh.SSH.solve`` with a stub that records + whether the model arrives in reformulated form, exercises the + ``to_netcdf`` warning path, and returns a synthetic solved + :class:`Model` so ``Model._assign_from_solved_model`` is exercised + end to end. """ - from linopy.io import read_netcdf - from linopy.sos_reformulation import ( - sos_reformulation_context, - suppress_serialization_warning, - ) + from linopy.remote.ssh import SSH + + def fake_solve( + self: SSH, model: Model, solver_name: str, **options: Any + ) -> Model: + observed["state_active"] = model._sos_reformulation_state is not None + observed["solver_name_arg"] = solver_name + model.to_netcdf(tmp_path / "sent.nc") # triggers any to_netcdf warning + for _name, var in model.variables.items(): + var.solution = var.labels * 0.0 + model.objective._value = 0.0 + model.status = "ok" + model.termination_condition = "optimal" + return model + + monkeypatch.setattr(SSH, "solve", fake_solve) + + def test_remote_brackets_and_suppresses_warning( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + from linopy.remote.ssh import SshSettings - class _Handler: - def solve_on_remote( - _self, - model: Model, - *, - reformulate_sos: bool | Literal["auto"] = False, - **kwargs: object, - ) -> Model: - solver_name = kwargs.get("solver_name") - assert solver_name is None or isinstance(solver_name, str) - with sos_reformulation_context( - model, solver_name, reformulate_sos - ) as applied: - observed["state_active"] = ( - model._sos_reformulation_state is not None - ) - observed["solver_name_arg"] = solver_name - with suppress_serialization_warning(active=applied): - model.to_netcdf(tmp_path / "sent.nc") - solved = read_netcdf(tmp_path / "sent.nc") - for _name, var in solved.variables.items(): - arr = np.zeros(var.labels.shape, dtype=float) - var.solution = xr.DataArray(arr, dims=var.labels.dims) - solved.objective.set_value(0.0) - solved.status = "ok" - solved.termination_condition = "optimal" - return solved - - return cast(RemoteHandler, _Handler()) - - def test_remote_brackets_and_suppresses_warning(self, tmp_path: Path) -> None: m = self._sos_model() observed: dict[str, object] = {} - handler = self._fake_handler(observed, tmp_path) + self._patch_ssh_solve(monkeypatch, observed, tmp_path) with warnings.catch_warnings(record=True) as captured: warnings.simplefilter("always") - m.solve(solver_name="highs", remote=handler, reformulate_sos=True) + m.solve( + solver_name="highs", + remote=SshSettings(hostname="ignored"), + reformulate_sos=True, + ) - # Reformulation was active when the handler ran (apply happened - # before the remote dispatch). + # Reformulation was active when the transport ran. assert observed["state_active"] is True assert observed["solver_name_arg"] == "highs" @@ -1258,26 +1246,38 @@ def test_remote_brackets_and_suppresses_warning(self, tmp_path: Path) -> None: assert "_sos_reform_x_y" not in m.variables def test_remote_skips_bracket_when_reformulate_sos_false( - self, tmp_path: Path + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch ) -> None: + from linopy.remote.ssh import SshSettings + m = self._sos_model() observed: dict[str, object] = {} - handler = self._fake_handler(observed, tmp_path) + self._patch_ssh_solve(monkeypatch, observed, tmp_path) with warnings.catch_warnings(record=True) as captured: warnings.simplefilter("always") - m.solve(solver_name="highs", remote=handler, reformulate_sos=False) + m.solve( + solver_name="highs", + remote=SshSettings(hostname="ignored"), + reformulate_sos=False, + ) # No reformulation happened — model still has the original SOS var - # when the handler sees it, and to_netcdf never warns. + # when the transport sees it, and to_netcdf never warns. assert observed["state_active"] is False assert not any("active SOS reformulation" in str(w.message) for w in captured) assert m._sos_reformulation_state is None - def test_remote_auto_requires_solver_name_with_sos(self, tmp_path: Path) -> None: + def test_remote_auto_requires_solver_name_with_sos( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + from linopy.remote.ssh import SshSettings + m = self._sos_model() observed: dict[str, object] = {} - handler = self._fake_handler(observed, tmp_path) + self._patch_ssh_solve(monkeypatch, observed, tmp_path) - with pytest.raises(ValueError, match="requires an explicit `solver_name`"): - m.solve(remote=handler, reformulate_sos="auto") + # Without an explicit solver_name, the transport dispatch refuses + # to run because there's no inner solver to ship. + with pytest.raises(ValueError, match="explicit `solver_name=`"): + m.solve(remote=SshSettings(hostname="ignored"), reformulate_sos="auto")