# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains several Interpolator classes"""
import torch
from torchdyn.numerics.solvers._constants import construct_4th
class Interpolator:
    """Base class for polynomial interpolants evaluated between two time points."""

    def __init__(self, order):
        # Order of the interpolation scheme; stored for callers, not used by this base class.
        self.order = order

    def sync_device_dtype(self, x, t_span):
        """Move the interpolator's ``bmid`` tensor, if any, onto the device/dtype of ``x``.

        Interpolators without a ``bmid`` attribute (or with ``bmid = None``) are left untouched.

        Args:
            x: tensor whose device and dtype ``bmid`` is converted to via ``bmid.to(x)``.
            t_span: passed through unchanged.

        Returns:
            The tuple ``(x, t_span)``, unchanged.
        """
        # getattr: the base class never defines `bmid`, so a plain attribute read
        # raised AttributeError for any interpolator that did not set it itself.
        bmid = getattr(self, "bmid", None)
        if bmid is not None:
            self.bmid = bmid.to(x)
        return x, t_span

    def fit(self, f0, f1, x0, x1, t, dt, **kwargs):
        """Compute interpolation coefficients; the base implementation returns ``None``.

        Subclasses override this and return the coefficient list consumed by ``evaluate``.
        """
        return None

    def evaluate(self, coefs, t0, t1, t):
        """Evaluate the polynomial ``sum_k coefs[k] * theta**k`` at time ``t``.

        ``theta = (t - t0) / (t1 - t0)`` is the position of ``t`` normalized to ``[t0, t1]``.

        Args:
            coefs: non-empty sequence of coefficients, lowest power first.
            t0: start of the interval.
            t1: end of the interval (must differ from ``t0``).
            t: evaluation time.

        Returns:
            The interpolated value.

        Raises:
            IndexError: if ``coefs`` is empty.
        """
        theta = (t - t0) / (t1 - t0)
        result = coefs[0]
        theta_power = 1
        for coef in coefs[1:]:
            # Repeated multiplication (not `theta ** k`) keeps the same rounding as before.
            theta_power = theta_power * theta
            # Out-of-place add: `result` starts as coefs[0], which `+=` would mutate.
            result = result + theta_power * coef
        return result
class Linear(Interpolator):
    """Placeholder for a linear interpolation scheme; constructing it is not supported yet."""

    def __init__(self):
        # Initialize the base like the other subclasses do, then refuse construction
        # with a message that says which scheme is missing.
        super().__init__(order=1)
        raise NotImplementedError("Linear interpolation is not implemented yet.")
class ThirdHermite(Interpolator):
    """Placeholder for a third-order Hermite interpolation scheme; not implemented yet."""

    def __init__(self):
        super().__init__(order=3)
        # A descriptive message instead of a bare NotImplementedError.
        raise NotImplementedError("Third-order Hermite interpolation is not implemented yet.")
class FourthOrder(Interpolator):
    """Fourth-order (quartic) polynomial interpolant."""

    def __init__(self, dtype):
        """4th order interpolation scheme.

        Args:
            dtype: forwarded to ``construct_4th`` to build ``self.bmid``. What that
                helper returns is not visible here; it must support ``.to(x)``,
                which the base ``sync_device_dtype`` applies to ``bmid``.
        """
        super().__init__(order=4)
        self.bmid = construct_4th(dtype)

    def fit(self, dt, f0, f1, x0, x1, x_mid, **kwargs):
        """Return the five coefficients of a quartic ``p(theta)``, lowest power first.

        For ``theta`` in ``[0, 1]`` the coefficients make the polynomial satisfy
        ``p(0) = x0``, ``p(1) = x1``, ``p(1/2) = x_mid``,
        ``p'(0) = dt * f0`` and ``p'(1) = dt * f1``.
        Entry ``k`` of the returned list multiplies ``theta**k``.
        Extra keyword arguments are accepted and ignored.
        """
        constant = x0
        linear = dt * f0
        quadratic = dt * (f1 - 4 * f0) - 11 * x0 - 5 * x1 + 16 * x_mid
        cubic = dt * (5 * f0 - 3 * f1) + 18 * x0 + 14 * x1 - 32 * x_mid
        quartic = 2 * dt * (f1 - f0) - 8 * (x1 + x0) + 16 * x_mid
        return [constant, linear, quadratic, cubic, quartic]
# Registry of interpolation-scheme names accepted by `str_to_interp`, mapped to their classes.
INTERP_DICT = {'4th': FourthOrder}


def str_to_interp(solver_name, dtype=torch.float32):
    """Transforms string specifying desired interpolation scheme into an instance of the Interpolator class.

    Args:
        solver_name: key into ``INTERP_DICT`` naming the interpolation scheme
            (e.g. ``'4th'``). Despite its name, it selects an interpolator, not a solver.
        dtype: forwarded to the selected interpolator's constructor.

    Returns:
        A new instance of the registered interpolator class.

    Raises:
        KeyError: if ``solver_name`` is not registered; the message lists the
            available names.
    """
    try:
        interpolator = INTERP_DICT[solver_name]
    except KeyError:
        # Keep KeyError for existing callers, but say what the valid choices are.
        available = ", ".join(repr(name) for name in INTERP_DICT)
        raise KeyError(
            "Unknown interpolator {!r}; available: {}".format(solver_name, available)
        ) from None
    return interpolator(dtype)