# Source code for astroNN.neuralode.odeint

import tensorflow as tf
from astroNN.neuralode.dop853 import dop853
from astroNN.neuralode.runge_kutta import rk4

# Registry of the available integrators, keyed by the lower-case name that
# ``odeint(method=...)`` looks up.
method_list = dict(dop853=dop853, rk4=rk4)

def odeint(func=None, x=None, t=None, aux=None, method='dop853', precision=tf.float32, *args, **kwargs):
    """
    Compute the numerical solution of a system of first order ordinary differential equations y'=f(x,y).
    Computation is done in float32 by default.

    :param func: function of the differential equation, usually take func([position, velocity], time) and return velocity, acceleration
    :type func: callable
    :param x: initial x, usually is [position, velocity]. A 1-D input is integrated as a single initial condition;
        an input with 2 or more dimensions is treated as a batch along its first axis.
    :type x: Union([tf.Tensor, numpy.ndarray, list])
    :param t: set of times at which one wants the result. For batched ``x``, a 1-D ``t`` is shared by every
        initial condition; otherwise ``t`` must already be batched along its first axis.
    :type t: Union([tf.Tensor, numpy.ndarray, list])
    :param aux: optional auxiliary data forwarded to the integrator as ``aux`` (per batch element when ``x`` is
        batched). Converted to ``precision`` like ``x`` and ``t``.
    :type aux: Union([None, tf.Tensor, numpy.ndarray, list])
    :param method: numerical integrator to use, available integrators are ['dop853', 'rk4'] (case-insensitive)
    :type method: str
    :param precision: float precision, tf.float32 or tf.float64
    :type precision: tf.DType

    :return: integrated result
    :rtype: tf.Tensor
    :raises NotImplementedError: if ``method`` is not an available integrator
    :raises TypeError: if ``precision`` is neither tf.float32 nor tf.float64

    :History: 2020-May-31 - Written - Henry Leung (University of Toronto)
    """
    try:
        ode_method = method_list[method.lower()]
    except KeyError:
        raise NotImplementedError(f"Method {method} is not implemented") from None

    # resolve the precision first so inputs can be built directly at that precision
    if precision == tf.float32:
        tf_float = tf.float32
    elif precision == tf.float64:
        tf_float = tf.float64
    else:
        raise TypeError(f"Data type {precision} not understood")

    def _as_float_tensor(value):
        # Non-tensors are created at the target dtype directly: tf.constant() on a Python float list
        # would otherwise default to float32 and lose precision before any cast to float64.
        if isinstance(value, tf.Tensor):
            return tf.cast(value, tf_float)
        return tf.constant(value, dtype=tf_float)

    x = _as_float_tensor(x)
    t = _as_float_tensor(t)
    aux_flag = aux is not None
    if aux_flag:
        aux = _as_float_tensor(aux)

    @tf.function
    def wrapped_func(x, t, *args, **kwargs):
        return func(x, t, *args, **kwargs)

    if len(x.shape) < 2:
        # single initial condition: integrate it directly, forwarding aux only when it was given
        aux_kwargs = {'aux': aux} if aux_flag else {}
        return ode_method(func=wrapped_func, x=x, t=t, tf_float=tf_float, *args, **aux_kwargs, **kwargs)[0]

    # batch of initial conditions along the first axis of x
    total_num = x.shape[0]

    if len(t.shape) < 2:
        # share one time grid across every initial condition
        t = tf.stack([t] * total_num)

    elems = (x, t, aux) if aux_flag else (x, t)

    def odeint_external(tensor):
        # tensor is one batch element of elems: (x_i, t_i) or (x_i, t_i, aux_i)
        aux_kwargs = {'aux': tensor[2]} if aux_flag else {}
        return ode_method(func=wrapped_func, x=tensor[0], t=tensor[1], tf_float=tf_float,
                          *args, **aux_kwargs, **kwargs)

    @tf.function
    def parallelized_func(tensor):
        return tf.map_fn(odeint_external, tensor)

    # NOTE(review): the integrator's output is assumed to be a tuple whose first element is the
    # solution (the single-sample path above indexes [0] the same way) — confirm against the integrators
    result = parallelized_func(elems)

    return result[0]