# Source code for torchdyn.numerics.utils

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Contains various utilities for `odeint` and numerical methods: norms, step-size initialization,
event callbacks for hybrid systems, vmapped matrix-Jacobian products and some additional goodies.
"""
import attr
import torch
import torch.nn as nn
from torch.distributions import Exponential
from torchcde import CubicSpline, hermite_cubic_coefficients_with_backward_differences

def make_norm(state):
    """Build a norm over an augmented vector that packs two blocks the size of ``state``.

    With ``n = state.numel()``, the returned callable reads ``aug_state[1:1 + n]`` and
    ``aug_state[1 + n:1 + 2 * n]`` and returns the larger of their ``hairer_norm`` values.
    Entry 0 and anything beyond ``1 + 2 * n`` do not contribute.

    NOTE(review): presumably entry 0 is a scalar (e.g. time) and the two blocks are the
    state and its adjoint, with ``aug_state`` flat; confirm against callers.

    Args:
        state: tensor whose element count fixes the size of each block.

    Returns:
        A callable ``norm_(aug_state)`` as described above.
    """
    n = state.numel()

    def norm_(aug_state):
        first, second = aug_state[1:1 + n], aug_state[1 + n:1 + 2 * n]
        return max(hairer_norm(first), hairer_norm(second))

    return norm_
def hairer_norm(tensor):
    """Return the root-mean-square of all entries of ``tensor``: ``sqrt(mean(tensor ** 2))``.

    The result is a 0-d tensor. Averaging (rather than summing) makes the norm independent
    of the number of elements.
    """
    mean_square = torch.mean(tensor ** 2)
    return mean_square.sqrt()
def init_step(f, f0, x0, t0, order, atol, rtol):
    """Choose an initial step size for an adaptive solver.

    Mirrors the starting-step heuristic of Hairer, Norsett & Wanner (Solving ODEs I, II.4):
    an explicit-Euler probe step estimates the local curvature of the solution.

    Args:
        f: vector field called as ``f(t, x)``.
        f0: derivative at the start; presumably ``f(t0, x0)`` (confirm with caller).
        x0: initial state.
        t0: initial time; the returned step is converted to its dtype/device.
        order: order used in the exponent ``1 / (order + 1)``.
        atol, rtol: absolute and relative tolerances defining the error scale.

    Returns:
        ``min(100 * h0, h1)`` converted with ``.to(t0)``.
    """
    eps = torch.tensor(1e-6, dtype=x0.dtype, device=x0.device)
    scale = atol + rtol * torch.abs(x0)

    d0 = hairer_norm(x0 / scale)
    d1 = hairer_norm(f0 / scale)
    # Fall back to a tiny first guess when the state or its derivative is (nearly) zero.
    degenerate = d0 < 1e-5 or d1 < 1e-5
    h0 = eps if degenerate else 0.01 * d0 / d1

    # One explicit Euler step of size h0, then a finite-difference estimate of the change in f.
    x_probe = x0 + h0 * f0
    f_probe = f(t0 + h0, x_probe)
    d2 = hairer_norm((f_probe - f0) / scale) / h0

    flat = d1 <= 1e-15 and d2 <= 1e-15
    if not flat:
        h1 = (0.01 / max(d1, d2)) ** (1. / float(order + 1))
    else:
        h1 = torch.max(eps, h0 * 1e-3)

    return torch.min(100 * h0, h1).to(t0)
@torch.no_grad()
def adapt_step(dt, error_ratio, safety, min_factor, max_factor, order):
    """Rescale the step ``dt`` from the latest error ratio (computed without autograd).

    The factor is ``safety / error_ratio ** (1 / order)``, clamped to the interval
    ``[min_factor, max_factor]``. Two special cases apply:

    * ``error_ratio == 0``: returns ``dt * max_factor`` directly.
    * ``error_ratio < 1``: the lower bound is raised to 1, so the step is never shrunk.

    ``max_factor`` (and ``min_factor`` when it is used) is passed to ``torch.min`` /
    ``torch.max`` together with a tensor, so it is expected to be a tensor.
    """
    if error_ratio == 0:
        return dt * max_factor

    lower = torch.ones_like(dt) if error_ratio < 1 else min_factor
    inv_order = torch.tensor(order, dtype=dt.dtype, device=dt.device).reciprocal()
    proposal = safety / error_ratio ** inv_order
    factor = torch.min(max_factor, torch.max(proposal, lower))
    return dt * factor
def dense_output(sol, t_sol, t_eval, return_spline=False):
    """Evaluate a discrete solution at arbitrary times via a Hermite cubic spline (torchcde).

    Args:
        sol: solution values at the times ``t_sol``. It is passed through
            ``permute(1, 0, 2)`` before fitting, so it must be 3-D; presumably
            ``(len(t_sol), batch, dim)`` -- TODO confirm.
        t_sol: times at which ``sol`` is known; converted to ``sol``'s dtype/device.
        t_eval: iterable of times at which the interpolant is evaluated.
        return_spline: when True, also return the fitted spline object.

    Returns:
        The evaluations at each entry of ``t_eval``, stacked along a new leading axis,
        or ``(evaluations, spline)`` when ``return_spline`` is True.
    """
    # Keep times and values on the same dtype/device before fitting the spline.
    t_sol = t_sol.to(sol)
    # torchcde's API is values-first: (x, t=None) for the coefficients and (coeffs, t=None)
    # for the spline. Passing `t` by keyword prevents the time grid from being read as data.
    spline_coeff = hermite_cubic_coefficients_with_backward_differences(sol.permute(1, 0, 2), t=t_sol)
    sol_spline = CubicSpline(spline_coeff, t=t_sol)
    sol_eval = torch.stack([sol_spline.evaluate(t) for t in t_eval])
    if return_spline:
        return sol_eval, sol_spline
    return sol_eval
class EventState:
    """Wrap a sequence of event indicators ``evid`` so two states can be compared with ``!=``.

    ``a != b`` returns how many positions differ. Positions are paired with ``zip``, so
    surplus trailing entries of the longer sequence are ignored. For scalar entries the
    count is truthy exactly when at least one position differs.

    Only ``!=`` is customised; ``==`` keeps Python's default identity comparison.
    """

    def __init__(self, evid):
        self.evid = evid

    def __ne__(self, other):
        mismatches = (mine != theirs for mine, theirs in zip(self.evid, other.evid))
        return sum(mismatches)
# eq=False: with attrs' default (eq=True) the class gets a generated __eq__ and
# `__hash__ = None`. That makes every instance compare equal to every other and makes the
# module unhashable, which breaks nn.Module.named_modules()/parameters()/modules() (they
# track visited modules in a set). eq=False keeps nn.Module's identity-based eq/hash.
# Subclasses that apply a plain @attr.s again must also pass eq=False.
@attr.s(eq=False)
class EventCallback(nn.Module):
    """Basic callback for hybrid differential equations. Must define an event condition and a state-jump.

    Subclasses override :meth:`check_event` (the event condition) and :meth:`jump_map`
    (the state jump); both raise ``NotImplementedError`` here.
    """

    def __attrs_post_init__(self):
        # attrs generates __init__, which does not call nn.Module.__init__; do it here so
        # the module's internal registries exist before any submodule/parameter is assigned.
        super().__init__()

    def check_event(self, t, x):
        """Event condition at time ``t`` and state ``x``; must be overridden."""
        raise NotImplementedError

    def jump_map(self, t, x):
        """State jump applied at an event; must be overridden."""
        raise NotImplementedError
# eq=False: attrs' default eq=True sets `__hash__ = None` (and makes all instances compare
# equal), which breaks nn.Module.named_modules()/parameters()/modules(). eq=False keeps
# nn.Module's identity-based eq/hash.
@attr.s(eq=False)
class StochasticEventCallback(nn.Module):
    """Base event callback for hybrid systems that carries a randomly sampled state ``s``.

    Subclasses override :meth:`check_event` and :meth:`jump_map`; both raise
    ``NotImplementedError`` here.
    """

    def __attrs_post_init__(self):
        # attrs' generated __init__ does not call nn.Module.__init__, so call it here first.
        super().__init__()
        # Unit-rate exponential distribution used by `initialize`.
        self.expdist = Exponential(1)

    def initialize(self, x0):
        """Draw ``self.s``: one Exponential(1) sample per entry of ``x0``'s first axis.

        NOTE(review): presumably the first axis is the batch and ``s`` acts as a random event
        threshold -- confirm with subclasses. The sample is drawn on the distribution's
        (default) device and is not moved to ``x0.device``.
        """
        self.s = self.expdist.sample(x0.shape[:1])

    def check_event(self, t, x):
        """Event condition at time ``t`` and state ``x``; must be overridden."""
        raise NotImplementedError

    def jump_map(self, t, x):
        """State jump applied at an event; must be overridden."""
        raise NotImplementedError
class RootLogger:
    """Keep named lists of diagnostic values recorded during an iterative procedure.

    ``logged`` maps each tracked name ('geval', 'z', 'dz', 'iteration', 'alpha', 'phi')
    to a list, created fresh per instance.

    NOTE(review): the attribute name ``logged`` and the method bodies were reconstructed
    from a copy of this file in which the ``self.logged`` references had been stripped;
    confirm against callers.
    """

    def __init__(self):
        self.logged = {'geval': [], 'z': [], 'dz': [], 'iteration': [], 'alpha': [], 'phi': []}

    def log(self, logged_data):
        """Overwrite entries of ``logged`` with those in the mapping ``logged_data``.

        Names absent from ``logged_data`` keep their current values. Names not already
        tracked are added.
        """
        # update(mapping) also accepts non-string keys, unlike update(**mapping).
        self.logged.update(logged_data)

    def permanent_log(self, logged_data):
        """Append ``logged_data[name]`` to the history of every tracked name.

        Raises:
            KeyError: if ``logged_data`` lacks one of the tracked names.
        """
        for key in list(self.logged):
            self.logged[key] = list(self.logged[key] + logged_data[key])
class WrapFunc(nn.Module):
    """Adapt a state-only callable ``f(x)`` to the ``forward(t, x)`` vector-field signature.

    The time argument ``t`` is accepted and ignored. Because the callable is stored with
    ``nn.Module`` attribute assignment, an ``nn.Module`` passed as ``f`` is registered as a
    submodule.
    """

    def __init__(self, f):
        super(WrapFunc, self).__init__()
        self.f = f

    def forward(self, t, x):
        wrapped = self.f
        return wrapped(x)