/*
* Computes the 2-beam amplitude for reflection h and the derivatives of the
* corresponding intensity |amplitude|^2.
*
* h              - Miller index of the diffracted beam
* Ug_            - input structure-factor-like term; it is divided by
*                  2*|K|*(K_g.N)/|K_g| before being placed into the 2x2 matrix
* dFcs           - in/out. On entry one complex value per parameter
*                  (presumably dUg_/d_param - name based, confirm with the
*                  caller). On exit every element is overwritten with the real
*                  value 2*(Re(rv)*Re(d) + Im(rv)*Im(d)), where rv is the
*                  returned amplitude and d is the amplitude derivative.
*                  If grad_thickness is true, the derivative with respect to
*                  thickness is appended as one extra trailing element.
* thickness      - sample thickness, enters only through exp(2*pi*i*v*thickness)
* K, RMf, N      - incident wave vector, matrix mapping h to Cartesian
*                  reciprocal space, and the vector K_g is projected onto
*                  (presumably the surface normal - confirm)
* grad_thickness - whether to append d(intensity)/d(thickness) to dFcs
*
* Returns the complex 2-beam amplitude rv (not the intensity).
*
* NOTE(review): the loop body contains cross-check code that may replace the
* analytical derivative written to dFcs[i] whenever an alternative estimate
* differs from it by more than 1e-3. See the notes inside the loop.
*/

static complex_t calc_2beam_ext(
  const miller::index<>& h,
  complex_t Ug_,
  af::shared<complex_t>& dFcs,
  FloatType thickness,
  cart_t const& K,
  mat3_t const& RMf,
  cart_t const& N,
  bool grad_thickness)
{
  FloatType Kl = K.length();
  cart_t K_g = K + RMf * cart_t(h[0], h[1], h[2]);
  FloatType K_gn = K_g * N;
  // excitation error of the diffracted beam
  FloatType Sg = (Kl * Kl - K_g.length_sq()) / (2 * Kl);
  // purely imaginary factor 2*pi*i*thickness used in all the exponents
  complex_t exp_k(0, 2 * scitbx::constants::pi * thickness);
  FloatType two_K_cos = 2 * Kl * K_gn / K_g.length();

  complex_t Ug = Ug_ / two_K_cos;
  // row-major 2x2 Hermitian matrix [[0, Ug*], [Ug, Sg]]
  complex_t A[4] = { 0,std::conj(Ug), Ug, Sg };
  FloatType v[2];

  // v is uninitialised before this call and read right after it, so the call
  // must fill it; A is used below as the eigenvector matrix, i.e. it is
  // expected to be overwritten in place (confirm against math_utils).
  math_utils<FloatType>::two_beam_eigen(&A[0], &v[0]);

  complex_t exps[2] = {
    std::exp(v[0] * exp_k),
    std::exp(v[1] * exp_k)
  };

  // element (1,0) of A * diag(exps) * A^H; A[1] and A[3] are treated as real
  complex_t rv = A[2] * exps[0] * std::conj(A[0]) + A[3].real() * exps[1] * A[1].real();
  for (size_t i = 0; i < dFcs.size(); i++) {
    // the incoming derivative gets the same scaling as Ug_
    complex_t dFc = dFcs[i] / two_K_cos;
    complex_t Gl[4] = { // A^-1 by dFc
      std::conj(A[2]) * dFc,
      std::conj(A[0] * dFc),
      A[3] * dFc,
      A[1] * std::conj(dFc)
    };
    complex_t G[4] = { // G1 by A = A^-1 * dFc * A
      Gl[0] * A[0] + Gl[1] * A[2],
      Gl[0] * A[1].real() + Gl[1] * A[3].real(),
      Gl[2] * A[0] + Gl[3] * A[2],
      Gl[2] * A[1].real() + Gl[3] * A[3].real(),
    };
    // Diagonal: derivative of exp(v*exp_k) with respect to v.
    // Off-diagonal: divided difference of the two exponentials.
    // NOTE(review): the off-diagonal term divides by v[0] - v[1], which is
    // zero when both eigenvalues coincide.
    complex_t X[4] = {
      exps[0] * exp_k,
        (exps[0] - exps[1]) / (v[0] - v[1]),
      0,
      exps[1] * exp_k
    };
    X[2] = X[1];
    // element-wise product of G and X
    complex_t V[4] = { G[0] * X[0], G[1] * X[1], G[2] * X[2], G[3] * X[3] };
    // V by first col of A* - A[1] and A[3] are real
    complex_t df[2] = {
      V[0] * std::conj(A[0]) + V[1] * A[1].real(),
      V[2] * std::conj(A[0]) + V[3] * A[1].real(),
    };
    // result - second row of A by df
    complex_t dp = A[2] * df[0] + A[3].real() * df[1];
    // 2*Re(conj(rv)*dp), i.e. the derivative of |rv|^2 when dp is d(rv)
    dFcs[i] = 2 * (rv.real() * dp.real() + rv.imag() * dp.imag());

    // Cross-checks of dp against alternative derivative estimates. Each check
    // replaces dFcs[i] when its estimate differs from dp by more than 1e-3,
    // so a later check can override an earlier one.
    {
      cmat_t A_ct(af::mat_grid(2, 2)),
        m_dFc(af::mat_grid(2, 2)),
        m_A(af::mat_grid(2, 2));
      // conjugate transpose of A
      A_ct(0, 0) = std::conj(A[0]);
      A_ct(0, 1) = std::conj(A[2]);
      A_ct(1, 0) = std::conj(A[1]);
      A_ct(1, 1) = std::conj(A[3]);

      // derivative of the input matrix: only the off-diagonal terms are set
      m_dFc(0, 1) = std::conj(dFc);
      m_dFc(1, 0) = dFc;

      m_A(0, 0) = A[0];
      m_A(0, 1) = A[1];
      m_A(1, 0) = A[2];
      m_A(1, 1) = A[3];

      af::shared<complex_t> d_D(2), m_L(2), m_DL(2);
      d_D[0] = exps[0];
      d_D[1] = exps[1];

      m_L[0] = v[0];
      m_L[1] = v[1];

      m_DL[0] = exps[0] * exp_k;
      m_DL[1] = exps[1] * exp_k;

      // check 1: derivative matrix from the generic helper
      cmat_t m_D = calc_dS_dx(m_dFc,
        d_D, m_L, m_A, A_ct, m_DL);

      if (std::abs(m_D(1, 0) - dp) > 1e-3) {
        const complex_t dp_ = m_D(1, 0);
        dFcs[i] = 2 * (rv.real() * dp_.real() + rv.imag() * dp_.imag());
      }

      // check 2: closed-form expressions in Sg and Ug; d_E and d_A look like
      // the derivatives of the eigenvalues and of the eigenvector components
      // (inferred from how they are used in dd below - confirm)
      FloatType denom = Sg * Sg + 4 * std::norm(Ug),
        denom_sq = sqrt(denom);
      complex_t d_E[2] = {
        2.0 * Ug / denom_sq,
        -2.0 * Ug / denom_sq
      };

      FloatType denom1 = (denom + Sg * denom_sq) * sqrt(2 * denom + 2 * Sg * denom_sq) * denom_sq;
      FloatType denom2 = (denom - Sg * denom_sq) * sqrt(2 * denom - 2 * Sg * denom_sq) * denom_sq;
      complex_t d_A[4] = {
        2.0 * Sg * (Sg * denom_sq + Sg * Sg + 2 * std::norm(Ug)) / denom1,
        -2.0 * Ug * Sg * (Sg + denom_sq) / denom1,
        -2.0 * Sg * (-Sg * denom_sq + Sg * Sg + 2 * std::norm(Ug)) / denom2,
        2.0 * Ug * Sg * (Sg - denom_sq) / denom2
      };

      // product-rule expansion of the expression used for rv above:
      //   A[2] * exps[0] * conj(A[0]) + A[3] * exps[1] * A[1]
      complex_t dd =
        d_A[2] * exps[0] * std::conj(A[0]) +
        A[2] * exps[0] * exp_k * d_E[0] * std::conj(A[0]) +
        A[2] * exps[0] * d_A[0] +

        d_A[3] * exps[1] * A[1] +
        A[3] * exps[1] * exp_k * d_E[1] * A[1] +
        A[3] * exps[1] * d_A[1];

      // check 3: central finite difference of the amplitude in Ug_
      // NOTE(review): only the real part of Ug_ is perturbed.
      FloatType eps = 1e-8;
      complex_t r_1 = calc_amp_2beam(h, Ug_ + eps, thickness, K, RMf, N);
      complex_t l_1 = calc_amp_2beam(h, Ug_ - eps, thickness, K, RMf, N);

      complex_t n_diff = (r_1 - l_1) / (2.0 * std::abs(eps));

      // NOTE(review): dFcs[i] has already been overwritten above with the
      // real-valued intensity derivative, so this is no longer the incoming
      // parameter derivative multiplied by n_diff - confirm the intent.
      complex_t xxx = dFcs[i] * n_diff;
      // NOTE(review): scaled by 2 * K_gn here, whereas Ug and dFc are scaled
      // by two_K_cos - confirm which one is intended.
      dd /= 2 * K_gn;
      dd *= dFc;
      if (std::abs(dd - dp) > 1e-3) {
        dFcs[i] = 2 * (rv.real() * dd.real() + rv.imag() * dd.imag());
      }
      if (std::abs(xxx - dp) > 1e-3) {
        dFcs[i] = 2 * (rv.real() * xxx.real() + rv.imag() * xxx.imag());
      }
    }
  }
  if (grad_thickness) {
    const complex_t k_dt(0, 2 * scitbx::constants::pi);
    // diagonal_dt by first row of A* (dI/dThickness)
    complex_t im_dt[2] = {
      exps[0] * (v[0] * k_dt) * std::conj(A[0]),
      exps[1] * (v[1] * k_dt) * A[1].real()
    };
    complex_t dt = A[2] * im_dt[0] + A[3].real() * im_dt[1];
    // appended after the per-parameter derivatives
    dFcs.push_back(2 * (rv.real() * dt.real() + rv.imag() * dt.imag()));
  }
  return rv;
}
