// =============================================================================
// PROJECT CHRONO - http://projectchrono.org
//
// Copyright (c) 2016 projectchrono.org
// All rights reserved.
//
// Use of this source code is governed by a BSD-style license that can be found
// in the LICENSE file at the top level of the distribution and at
// http://projectchrono.org/license-chrono.txt.
//
// =============================================================================
// Authors: Hammad Mazhar, Radu Serban
// =============================================================================
//
// Handling of bilateral constraints for the system and Jacobian calculation
//
// =============================================================================

#include <algorithm>
#include <array>

#include "chrono_multicore/constraints/ChConstraintBilateral.h"
#include "chrono_multicore/ChMulticoreDefines.h"
#include "chrono/multicore_math/ChMulticoreMath.h"

#include "chrono/solver/ChConstraintTwoBodies.h"
#include "chrono/solver/ChConstraintTwoGeneric.h"
#include "chrono/solver/ChConstraintThreeGeneric.h"
#include "chrono/physics/ChBody.h"
#include "chrono/physics/ChShaft.h"

using namespace chrono;

// Copy the right-hand side of every bilateral constraint into the global
// vector b. Bilateral rows are stored after the num_unilaterals unilateral rows.
void ChConstraintBilateral::Build_b() {
    std::vector<ChConstraint*>& mconstraints = data_manager->system_descriptor->GetConstraints();
    const auto offset = data_manager->num_unilaterals;

#pragma omp parallel for
    for (int index = 0; index < (signed)data_manager->num_bilaterals; index++) {
        int cntr = data_manager->host_data.bilateral_mapping[index];
        // Read the RHS through the ChConstraint base. Bilateral rows may hold
        // ChConstraintTwoGeneric / ChConstraintThreeGeneric objects (shaft
        // constraints), so downcasting every entry to ChConstraintTwoBodies,
        // as was done previously, is a cast to the wrong type.
        data_manager->host_data.b[offset + index] = mconstraints[cntr]->GetRightHandSide();
    }
}

// Zero the entries of E belonging to the bilateral rows. These rows follow
// the num_unilaterals unilateral rows. (E is presumably the per-row
// compliance/regularization term; TODO confirm against the solver.)
void ChConstraintBilateral::Build_E() {
    const auto first_row = data_manager->num_unilaterals;
    const int row_count = (signed)data_manager->num_bilaterals;

#pragma omp parallel for
    for (int k = 0; k < row_count; k++) {
        data_manager->host_data.E[first_row + k] = 0;
    }
}

void ChConstraintBilateral::Build_D() {
    // Grab the list of all bilateral constraints present in the system
    // (note that this includes possibly inactive constraints)
    std::vector<ChConstraint*>& mconstraints = data_manager->system_descriptor->GetConstraints();

    // Loop over the active constraints and fill in the rows of the Jacobian,
    // taking into account the type of each constraint.
    SubMatrixType D_b_T = _DBT_;

    //#pragma omp parallel for
    for (int index = 0; index < (signed)data_manager->num_bilaterals; index++) {
        int cntr = data_manager->host_data.bilateral_mapping[index];
        int type = data_manager->host_data.bilateral_type[cntr];
        int row = index;

        switch (type) {
            case BilateralType::BODY_BODY: {
                ChConstraintTwoBodies* mbilateral = (ChConstraintTwoBodies*)(mconstraints[cntr]);

                int idA = ((ChBody*)((ChVariablesBody*)(mbilateral->GetVariables_a()))->GetUserData())->GetIndex();
                int idB = ((ChBody*)((ChVariablesBody*)(mbilateral->GetVariables_b()))->GetUserData())->GetIndex();
                int colA = idA * 6;
                int colB = idB * 6;

                D_b_T(row, colA + 0) = mbilateral->Get_Cq_a()(0);
                D_b_T(row, colA + 1) = mbilateral->Get_Cq_a()(1);
                D_b_T(row, colA + 2) = mbilateral->Get_Cq_a()(2);

                D_b_T(row, colA + 3) = mbilateral->Get_Cq_a()(3);
                D_b_T(row, colA + 4) = mbilateral->Get_Cq_a()(4);
                D_b_T(row, colA + 5) = mbilateral->Get_Cq_a()(5);

                D_b_T(row, colB + 0) = mbilateral->Get_Cq_b()(0);
                D_b_T(row, colB + 1) = mbilateral->Get_Cq_b()(1);
                D_b_T(row, colB + 2) = mbilateral->Get_Cq_b()(2);

                D_b_T(row, colB + 3) = mbilateral->Get_Cq_b()(3);
                D_b_T(row, colB + 4) = mbilateral->Get_Cq_b()(4);
                D_b_T(row, colB + 5) = mbilateral->Get_Cq_b()(5);
            } break;

            case BilateralType::SHAFT_SHAFT: {
                ChConstraintTwoGeneric* mbilateral = (ChConstraintTwoGeneric*)(mconstraints[cntr]);

                int idA = ((ChVariablesShaft*)(mbilateral->GetVariables_a()))->GetShaft()->GetIndex();
                int idB = ((ChVariablesShaft*)(mbilateral->GetVariables_b()))->GetShaft()->GetIndex();

                int colA = data_manager->num_rigid_bodies * 6 + idA;
                int colB = data_manager->num_rigid_bodies * 6 + idB;

                D_b_T(row, colA) = mbilateral->Get_Cq_a()(0);
                D_b_T(row, colB) = mbilateral->Get_Cq_b()(0);
            } break;

            case BilateralType::SHAFT_BODY: {
                ChConstraintTwoGeneric* mbilateral = (ChConstraintTwoGeneric*)(mconstraints[cntr]);

                int idA = ((ChVariablesShaft*)(mbilateral->GetVariables_a()))->GetShaft()->GetIndex();
                int idB = ((ChBody*)((ChVariablesBody*)(mbilateral->GetVariables_b()))->GetUserData())->GetIndex();

                int colA = data_manager->num_rigid_bodies * 6 + idA;
                int colB = idB * 6;

                D_b_T(row, colA) = mbilateral->Get_Cq_a()(0);

                D_b_T(row, colB + 0) = mbilateral->Get_Cq_b()(0);
                D_b_T(row, colB + 1) = mbilateral->Get_Cq_b()(1);
                D_b_T(row, colB + 2) = mbilateral->Get_Cq_b()(2);

                D_b_T(row, colB + 3) = mbilateral->Get_Cq_b()(3);
                D_b_T(row, colB + 4) = mbilateral->Get_Cq_b()(4);
                D_b_T(row, colB + 5) = mbilateral->Get_Cq_b()(5);
            } break;

            case BilateralType::SHAFT_SHAFT_SHAFT: {
                ChConstraintThreeGeneric* mbilateral = (ChConstraintThreeGeneric*)(mconstraints[cntr]);
                int idA = ((ChVariablesShaft*)(mbilateral->GetVariables_a()))->GetShaft()->GetIndex();
                int idB = ((ChVariablesShaft*)(mbilateral->GetVariables_b()))->GetShaft()->GetIndex();
                int idC = ((ChVariablesShaft*)(mbilateral->GetVariables_c()))->GetShaft()->GetIndex();

                int colA = data_manager->num_rigid_bodies * 6 + idA;
                int colB = data_manager->num_rigid_bodies * 6 + idB;
                int colC = data_manager->num_rigid_bodies * 6 + idC;

                D_b_T(row, colA) = mbilateral->Get_Cq_a()(0);
                D_b_T(row, colB) = mbilateral->Get_Cq_b()(0);
                D_b_T(row, colC) = mbilateral->Get_Cq_c()(0);
            } break;

            case BilateralType::SHAFT_SHAFT_BODY: {
                ChConstraintThreeGeneric* mbilateral = (ChConstraintThreeGeneric*)(mconstraints[cntr]);
                int idA = ((ChVariablesShaft*)(mbilateral->GetVariables_a()))->GetShaft()->GetIndex();
                int idB = ((ChVariablesShaft*)(mbilateral->GetVariables_b()))->GetShaft()->GetIndex();
                int idC = ((ChBody*)((ChVariablesBody*)(mbilateral->GetVariables_c()))->GetUserData())->GetIndex();

                int colA = data_manager->num_rigid_bodies * 6 + idA;
                int colB = data_manager->num_rigid_bodies * 6 + idB;
                int colC = idC * 6;

                D_b_T(row, colA) = mbilateral->Get_Cq_a()(0);
                D_b_T(row, colB) = mbilateral->Get_Cq_b()(0);

                D_b_T(row, colC + 0) = mbilateral->Get_Cq_c()(0);
                D_b_T(row, colC + 1) = mbilateral->Get_Cq_c()(1);
                D_b_T(row, colC + 2) = mbilateral->Get_Cq_c()(2);

                D_b_T(row, colC + 3) = mbilateral->Get_Cq_c()(3);
                D_b_T(row, colC + 4) = mbilateral->Get_Cq_c()(4);
                D_b_T(row, colC + 5) = mbilateral->Get_Cq_c()(5);
            } break;
        }
    }
}

void ChConstraintBilateral::GenerateSparsity() {
    // Grab the list of all bilateral constraints present in the system
    // (note that this includes possibly inactive constraints)

    std::vector<ChConstraint*>& mconstraints = data_manager->system_descriptor->GetConstraints();

    // Loop over the active constraints and fill in the sparsity pattern of the
    // Jacobian, taking into account the type of each constraint.
    // Note that the data for a Blaze compressed matrix must be filled in increasing
    // order of the column index for each row. Recall that body states are always
    // before shaft states.

    CompressedMatrix<real>& D_b_T = data_manager->host_data.D_T;
    int off = data_manager->num_unilaterals;
    for (int index = 0; index < (signed)data_manager->num_bilaterals; index++) {
        int cntr = data_manager->host_data.bilateral_mapping[index];
        int type = data_manager->host_data.bilateral_type[cntr];
        int row = off + index;

        int col1;
        int col2;
        int col3;

        switch (type) {
            case BilateralType::BODY_BODY: {
                ChConstraintTwoBodies* mbilateral = (ChConstraintTwoBodies*)(mconstraints[cntr]);

                int idA = ((ChBody*)((ChVariablesBody*)(mbilateral->GetVariables_a()))->GetUserData())->GetIndex();
                int idB = ((ChBody*)((ChVariablesBody*)(mbilateral->GetVariables_b()))->GetUserData())->GetIndex();

                if (idA < idB) {
                    col1 = idA * 6;
                    col2 = idB * 6;
                } else {
                    col1 = idB * 6;
                    col2 = idA * 6;
                }

                D_b_T.append(row, col1 + 0, 1);
                D_b_T.append(row, col1 + 1, 1);
                D_b_T.append(row, col1 + 2, 1);
                D_b_T.append(row, col1 + 3, 1);
                D_b_T.append(row, col1 + 4, 1);
                D_b_T.append(row, col1 + 5, 1);

                D_b_T.append(row, col2 + 0, 1);
                D_b_T.append(row, col2 + 1, 1);
                D_b_T.append(row, col2 + 2, 1);
                D_b_T.append(row, col2 + 3, 1);
                D_b_T.append(row, col2 + 4, 1);
                D_b_T.append(row, col2 + 5, 1);
            } break;

            case BilateralType::SHAFT_SHAFT: {
                ChConstraintTwoGeneric* mbilateral = (ChConstraintTwoGeneric*)(mconstraints[cntr]);

                int idA = ((ChVariablesShaft*)(mbilateral->GetVariables_a()))->GetShaft()->GetIndex();
                int idB = ((ChVariablesShaft*)(mbilateral->GetVariables_b()))->GetShaft()->GetIndex();

                if (idA < idB) {
                    col1 = data_manager->num_rigid_bodies * 6 + idA;
                    col2 = data_manager->num_rigid_bodies * 6 + idB;
                } else {
                    col1 = data_manager->num_rigid_bodies * 6 + idB;
                    col2 = data_manager->num_rigid_bodies * 6 + idA;
                }

                D_b_T.append(row, col1, 1);
                D_b_T.append(row, col2, 1);
            } break;

            case BilateralType::SHAFT_BODY: {
                ChConstraintTwoGeneric* mbilateral = (ChConstraintTwoGeneric*)(mconstraints[cntr]);

                int idA = ((ChVariablesShaft*)(mbilateral->GetVariables_a()))->GetShaft()->GetIndex();
                int idB = ((ChBody*)((ChVariablesBody*)(mbilateral->GetVariables_b()))->GetUserData())->GetIndex();

                col1 = idB * 6;
                col2 = data_manager->num_rigid_bodies * 6 + idA;

                D_b_T.append(row, col1 + 0, 1);
                D_b_T.append(row, col1 + 1, 1);
                D_b_T.append(row, col1 + 2, 1);
                D_b_T.append(row, col1 + 3, 1);
                D_b_T.append(row, col1 + 4, 1);
                D_b_T.append(row, col1 + 5, 1);

                D_b_T.append(row, col2, 1);
            } break;

            case BilateralType::SHAFT_SHAFT_SHAFT: {
                ChConstraintThreeGeneric* mbilateral = (ChConstraintThreeGeneric*)(mconstraints[cntr]);
                std::vector<int> ids(3);
                ids[0] = ((ChVariablesShaft*)(mbilateral->GetVariables_a()))->GetShaft()->GetIndex();
                ids[1] = ((ChVariablesShaft*)(mbilateral->GetVariables_b()))->GetShaft()->GetIndex();
                ids[2] = ((ChVariablesShaft*)(mbilateral->GetVariables_c()))->GetShaft()->GetIndex();

                std::sort(ids.begin(), ids.end());
                col1 = data_manager->num_rigid_bodies * 6 + ids[0];
                col2 = data_manager->num_rigid_bodies * 6 + ids[1];
                col3 = data_manager->num_rigid_bodies * 6 + ids[2];

                D_b_T.append(row, col1, 1);
                D_b_T.append(row, col2, 1);
                D_b_T.append(row, col3, 1);
            } break;

            case BilateralType::SHAFT_SHAFT_BODY: {
                ChConstraintThreeGeneric* mbilateral = (ChConstraintThreeGeneric*)(mconstraints[cntr]);
                int idA = ((ChVariablesShaft*)(mbilateral->GetVariables_a()))->GetShaft()->GetIndex();
                int idB = ((ChVariablesShaft*)(mbilateral->GetVariables_b()))->GetShaft()->GetIndex();
                int idC = ((ChBody*)((ChVariablesBody*)(mbilateral->GetVariables_c()))->GetUserData())->GetIndex();

                col1 = idC * 6;
                if (idA < idB) {
                    col2 = data_manager->num_rigid_bodies * 6 + idA;
                    col3 = data_manager->num_rigid_bodies * 6 + idB;
                } else {
                    col2 = data_manager->num_rigid_bodies * 6 + idB;
                    col3 = data_manager->num_rigid_bodies * 6 + idA;
                }

                D_b_T.append(row, col1 + 0, 1);
                D_b_T.append(row, col1 + 1, 1);
                D_b_T.append(row, col1 + 2, 1);
                D_b_T.append(row, col1 + 3, 1);
                D_b_T.append(row, col1 + 4, 1);
                D_b_T.append(row, col1 + 5, 1);

                D_b_T.append(row, col2, 1);
                D_b_T.append(row, col3, 1);
            } break;
        }

        D_b_T.finalize(row);
    }
}
