Skip to content

File OrbitalElements.cpp

File List > astrea > astro > astro > state > orbital_elements > OrbitalElements.cpp

Go to the documentation of this file

/*
 * The GNU Lesser General Public License (LGPL)
 *
 * Copyright (c) 2025 Jay Iuliano
 *
 * This file is part of Astrea.
 * Astrea is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
 * Astrea is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty
 * of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. You should
 * have received a copy of the GNU General Public License along with Astrea. If not, see <https://www.gnu.org/licenses/>.
 */

#include <astro/state/orbital_elements/OrbitalElements.hpp>

#include <cassert>
#include <cmath>
#include <iomanip>
#include <iostream>
#include <stdexcept>

#include <mp-units/math.h>

#include <math/operations.hpp>

using namespace mp_units;

namespace astrea {
namespace astro {

// Streams whichever element set is currently held by forwarding to that set's own operator<<.
std::ostream& operator<<(std::ostream& os, const OrbitalElements& elements)
{
    const auto printActive = [&os](const auto& activeSet) { os << activeSet; };
    std::visit(printActive, elements._elements);
    return os;
}

// Two OrbitalElements compare equal only if they hold the same element set (same variant index)
// and that set's own operator== reports equality.
bool OrbitalElements::operator==(const OrbitalElements& other) const
{
    const bool sameSet = (_elements.index() == other.extract().index());
    if (!sameSet) [[unlikely]] { return false; }

    // Index equality was verified above, so std::get on `other` selects the matching alternative.
    const auto compareSameSet = [&other](const auto& lhs) -> bool {
        using SetT = std::remove_cvref_t<decltype(lhs)>;
        return lhs == std::get<SetT>(other._elements);
    };
    return std::visit(compareSameSet, _elements);
}

// Adds two element sets of the same type. Throws (via throw_mismatched_types) when `other` holds a different set.
OrbitalElements OrbitalElements::operator+(const OrbitalElements& other) const
{
    const auto addSameSet = [&other](const auto& lhs) -> OrbitalElements {
        using SetT      = std::remove_cvref_t<decltype(lhs)>;
        const SetT* rhs = std::get_if<SetT>(&other._elements);
        if (rhs == nullptr) { throw_mismatched_types(); }
        return lhs + *rhs;
    };
    return std::visit(addSameSet, _elements);
}
// In-place addition. `other` must hold the same element set as *this, otherwise throw_mismatched_types() is raised.
OrbitalElements& OrbitalElements::operator+=(const OrbitalElements& other)
{
    const auto accumulate = [&other](auto& lhs) {
        using SetT = std::remove_cvref_t<decltype(lhs)>;
        if (const SetT* rhs = std::get_if<SetT>(&other._elements)) {
            lhs += *rhs;
            return;
        }
        throw_mismatched_types();
    };
    std::visit(accumulate, _elements);
    return *this;
}

// Subtracts `other` from *this. Both operands must hold the same element set; otherwise throw_mismatched_types() is raised.
OrbitalElements OrbitalElements::operator-(const OrbitalElements& other) const
{
    return std::visit(
        [&other](const auto& minuend) -> OrbitalElements {
            using SetT         = std::remove_cvref_t<decltype(minuend)>;
            const bool sameSet = std::holds_alternative<SetT>(other._elements);
            if (!sameSet) { throw_mismatched_types(); }
            const SetT& subtrahend = std::get<SetT>(other._elements);
            return minuend - subtrahend;
        },
        _elements
    );
}
// In-place subtraction. `other` must hold the same element set as *this, otherwise throw_mismatched_types() is raised.
OrbitalElements& OrbitalElements::operator-=(const OrbitalElements& other)
{
    std::visit(
        [&other](auto& lhs) {
            using SetT      = std::remove_cvref_t<decltype(lhs)>;
            const SetT* rhs = std::get_if<SetT>(&other._elements);
            if (!rhs) { throw_mismatched_types(); }
            lhs -= *rhs;
        },
        _elements
    );
    return *this;
}

// Multiplies the held element set by a dimensionless factor, delegating to that set's operator*.
OrbitalElements OrbitalElements::operator*(const Unitless& multiplier) const
{
    const auto scaleSet = [&multiplier](const auto& elementSet) -> OrbitalElements { return elementSet * multiplier; };
    return std::visit(scaleSet, _elements);
}
// In-place counterpart of operator*(const Unitless&): delegates to the held set's operator*=.
OrbitalElements& OrbitalElements::operator*=(const Unitless& multiplier)
{
    const auto scaleInPlace = [&multiplier](auto& elementSet) { elementSet *= multiplier; };
    std::visit(scaleInPlace, _elements);
    return *this;
}

// Divides the held element set by a time, producing an OrbitalElementPartials via that set's operator/.
OrbitalElementPartials OrbitalElements::operator/(const Time& divisor) const
{
    const auto divideByTime = [&divisor](const auto& elementSet) -> OrbitalElementPartials {
        return elementSet / divisor;
    };
    return std::visit(divideByTime, _elements);
}
// Divides the held element set by a dimensionless factor, delegating to that set's operator/.
OrbitalElements OrbitalElements::operator/(const Unitless& divisor) const
{
    const auto divideSet = [&divisor](const auto& elementSet) -> OrbitalElements { return elementSet / divisor; };
    return std::visit(divideSet, _elements);
}
// In-place counterpart of operator/(const Unitless&): delegates to the held set's operator/=.
OrbitalElements& OrbitalElements::operator/=(const Unitless& divisor)
{
    const auto divideInPlace = [&divisor](auto& elementSet) { elementSet /= divisor; };
    std::visit(divideInPlace, _elements);
    return *this;
}

// Flattens the held element set into a vector of Unitless values using that set's force_to_vector().
std::vector<Unitless> OrbitalElements::force_to_vector() const
{
    const auto flatten = [](const auto& elementSet) -> std::vector<Unitless> { return elementSet.force_to_vector(); };
    return std::visit(flatten, _elements);
}

// Interpolates between *this and `other` by forwarding all arguments to the held set's own interpolate().
// Both operands must hold the same element set; otherwise throw_mismatched_types() is raised.
OrbitalElements
    OrbitalElements::interpolate(const Time& thisTime, const Time& otherTime, const OrbitalElements& other, const GravParam& mu, const Time& targetTime) const
{
    const auto interpolateSameSet = [&](const auto& startSet) -> OrbitalElements {
        using SetT           = std::remove_cvref_t<decltype(startSet)>;
        const SetT* otherSet = std::get_if<SetT>(&other._elements);
        if (otherSet == nullptr) { throw_mismatched_types(); }
        return startSet.interpolate(thisTime, otherTime, *otherSet, mu, targetTime);
    };
    return std::visit(interpolateSameSet, _elements);
}

// Read-only access to the underlying element-set variant.
const OrbitalElements::ElementVariant& OrbitalElements::extract() const
{
    return _elements;
}
// Mutable access to the underlying element-set variant.
OrbitalElements::ElementVariant& OrbitalElements::extract()
{
    return _elements;
}

// Mutating conversion: replaces this object's contents with the element set selected by `idx`
// (see convert_to_set_impl) and returns *this for chaining.
OrbitalElements& OrbitalElements::convert_to_set(const std::size_t idx, const GravParam& mu)
{
    return *this = convert_to_set_impl(idx, mu);
}

// Non-mutating conversion: returns a converted copy and leaves *this untouched.
OrbitalElements OrbitalElements::convert_to_set(const std::size_t idx, const GravParam& mu) const
{
    OrbitalElements converted = convert_to_set_impl(idx, mu);
    return converted;
}

// Converts the held element set into the set whose variant index is `idx`.
// Throws std::runtime_error for an index without a matching branch.
// TODO: every supported set is listed by hand, so adding a new variant alternative requires editing this dispatch.
OrbitalElements OrbitalElements::convert_to_set_impl(const std::size_t idx, const GravParam& mu) const
{
    if (idx == OrbitalElements::get_set_id<Cartesian>()) { return in_element_set<Cartesian>(mu); }
    if (idx == OrbitalElements::get_set_id<Keplerian>()) { return in_element_set<Keplerian>(mu); }
    if (idx == OrbitalElements::get_set_id<Equinoctial>()) { return in_element_set<Equinoctial>(mu); }
    throw std::runtime_error("Unrecognized element set requested.");
}

// Rebuilds the element set selected by `idx` from its flattened representation by calling that
// set's static from_vector(). Throws std::runtime_error for an unsupported index.
OrbitalElements OrbitalElements::from_vector(const std::vector<Unitless>& vec, const std::size_t idx)
{
    switch (idx) {
        case OrbitalElements::get_set_id<Cartesian>(): {
            return OrbitalElements(Cartesian::from_vector(vec));
        }
        case OrbitalElements::get_set_id<Keplerian>(): {
            return OrbitalElements(Keplerian::from_vector(vec));
        }
        case OrbitalElements::get_set_id<Equinoctial>(): {
            return OrbitalElements(Equinoctial::from_vector(vec));
        }
        default: break;
    }
    throw std::runtime_error("Invalid orbital element set index for from_vector.");
}


// Multiplies the held partial set by a time, producing an OrbitalElements via that set's operator*.
OrbitalElements OrbitalElementPartials::operator*(const Time& time) const
{
    const auto timesDuration = [&time](const auto& partialSet) -> OrbitalElements { return partialSet * time; };
    return std::visit(timesDuration, _elements);
}

// Streams whichever partial set is currently held by forwarding to that set's own operator<<.
std::ostream& operator<<(std::ostream& os, const OrbitalElementPartials& elements)
{
    const auto printActive = [&os](const auto& activePartials) { os << activePartials; };
    std::visit(printActive, elements._elements);
    return os;
}

// Read-only access to the underlying partial-set variant.
const OrbitalElementPartials::PartialVariant& OrbitalElementPartials::extract() const
{
    return _elements;
}

// Mutable access to the underlying partial-set variant.
OrbitalElementPartials::PartialVariant& OrbitalElementPartials::extract()
{
    return _elements;
}

// Flattens the held partial set into a vector of Unitless values using that set's force_to_vector().
std::vector<Unitless> OrbitalElementPartials::force_to_vector() const
{
    const auto flatten = [](const auto& partialSet) -> std::vector<Unitless> { return partialSet.force_to_vector(); };
    return std::visit(flatten, _elements);
}

// Raises the error used throughout this file when two operands hold different element sets.
// Always throws std::runtime_error; it never returns normally.
void throw_mismatched_types()
{
    const char* const message = "Cannot perform operations on orbital elements from different element sets.";
    throw std::runtime_error(message);
}


bool nearly_equal(const OrbitalElements& first, const OrbitalElements& second, bool ignoreFastVariable, Unitless relTol)
{
    if (first.index() != second.index()) { throw_mismatched_types(); }

    const std::vector<Unitless> firstScaled  = first.force_to_vector();
    const std::vector<Unitless> secondScaled = second.force_to_vector();
    for (int ii = 0; ii < 6; ii++) {
        if (!math::nearly_equal(firstScaled[ii], secondScaled[ii], relTol)) { return false; }
    }
    return true;
}

bool nearly_equal(const OrbitalElementPartials& first, const OrbitalElementPartials& second, bool ignoreFastVariable, Unitless relTol)
{
    if (first.index() != second.index()) { throw_mismatched_types(); }

    // arbitrary normalization. shouldn't affect relative size
    const Time scale                         = 1.0 * mp_units::si::unit_symbols::s;
    const std::vector<Unitless> firstScaled  = (first * scale).force_to_vector();
    const std::vector<Unitless> secondScaled = (second * scale).force_to_vector();
    for (int ii = 0; ii < 6; ii++) {
        if (!math::nearly_equal(firstScaled[ii], secondScaled[ii], relTol)) { return false; }
    }
    return true;
}

} // namespace astro
} // namespace astrea