
#include<jni.h>
#include<stdlib.h>
#include"z3.h"
#ifdef __cplusplus
extern "C" {
#endif

/* DLL_VIS: expands to a default-visibility attribute when the compiler
   defines __GNUC__ with a major version of 4 or higher, and to nothing
   for every other compiler. */
#if defined(__GNUC__) && (__GNUC__ >= 4)
#define DLL_VIS __attribute__ ((visibility ("default")))
#else
#define DLL_VIS
#endif

#if defined(__LP64__) || defined(_WIN64)

// Variant selected when __LP64__ or _WIN64 is defined.
// GETLONGAELEMS declares `T * NEW` and initialises it with the pointer returned
// by GetLongArrayElements(OLD), cast to T*; when OLD is 0, NEW is 0.
// The expansion uses a variable named `jenv` that must be in scope at the use site.
#define GETLONGAELEMS(T,OLD,NEW)                                   \
  T * NEW = (OLD == 0) ? 0 : (T*) jenv->GetLongArrayElements(OLD, NULL);
// RELEASELONGAELEMS returns the buffer obtained above with mode JNI_ABORT
// (the JNI mode that frees the buffer without copying changes back).
// Nothing is done when OLD is 0.
#define RELEASELONGAELEMS(OLD,NEW)                                 \
  if (OLD != 0) jenv->ReleaseLongArrayElements(OLD, (jlong *) NEW, JNI_ABORT);     

// Variant selected when __LP64__ or _WIN64 is defined.
// GETLONGAREGION copies SZ elements of the Java long array OLD, starting at
// index Z, directly into NEW (viewed as jlong*). The T parameter is unused in
// this variant. Note that the expansion already ends with a semicolon.
#define GETLONGAREGION(T,OLD,Z,SZ,NEW)                               \
  jenv->GetLongArrayRegion(OLD,Z,(jsize)SZ,(jlong*)NEW);             
// SETLONGAREGION writes SZ elements from NEW (viewed as jlong*) into the Java
// long array OLD, starting at index Z. The expansion has no trailing semicolon.
#define SETLONGAREGION(OLD,Z,SZ,NEW)                               \
  jenv->SetLongArrayRegion(OLD,Z,(jsize)SZ,(jlong*)NEW)              

#else

// Variant used when neither __LP64__ nor _WIN64 is defined.
// GETLONGAELEMS declares `T * NEW` (initially 0). When OLD is non-null and its
// elements can be obtained, NEW becomes a freshly allocated T[length(OLD)] whose
// entries are the jlong elements converted with reinterpret_cast<T>; the JNI
// buffer is released immediately with JNI_ABORT. The caller owns NEW and must
// free it with delete[] (see RELEASELONGAELEMS).
// The array is allocated as T[] so that the allocation size and the delete[]
// type always match T, and a NULL result from GetLongArrayElements leaves NEW
// at 0 instead of being dereferenced.
// The expansion uses a variable named `jenv` that must be in scope at the use site.
#define GETLONGAELEMS(T,OLD,NEW)                                   \
  T * NEW = 0; {                                                   \
  jlong * temp = (OLD == 0) ? 0 : jenv->GetLongArrayElements(OLD, NULL); \
  if (temp != 0) {                                                   \
    unsigned int size = jenv->GetArrayLength(OLD);                   \
    NEW = new T[size];                                               \
    for (unsigned i=0; i < size; i++)                                \
      NEW[i] = reinterpret_cast<T>(temp[i]);                         \
    jenv->ReleaseLongArrayElements(OLD, temp, JNI_ABORT);            \
  }                                                                  \
  }                                                                    

// Variant used when neither __LP64__ nor _WIN64 is defined.
// Frees the converted copy that GETLONGAELEMS allocated with new[]. OLD is not
// used here because the JNI buffer was already released by GETLONGAELEMS.
// delete[] on a null NEW is a no-op.
#define RELEASELONGAELEMS(OLD,NEW)                                   \
  delete [] NEW;                                                     

// Variant used when neither __LP64__ nor _WIN64 is defined.
// Copies SZ elements of the Java long array OLD, starting at index Z, into a
// temporary jlong buffer, then stores each element into NEW[i] after converting
// it with reinterpret_cast<T>. NEW is not declared by the macro; it must already
// be an indexable destination with room for SZ entries. SZ is evaluated more
// than once.
#define GETLONGAREGION(T,OLD,Z,SZ,NEW)                              \
  {                                                                 \
    jlong * temp = new jlong[SZ];                                   \
    jenv->GetLongArrayRegion(OLD,Z,(jsize)SZ,(jlong*)temp);         \
    for (int i = 0; i < (SZ); i++)                                  \
      NEW[i] = reinterpret_cast<T>(temp[i]);                        \
    delete [] temp;                                                 \
  }

// Variant used when neither __LP64__ nor _WIN64 is defined.
// Converts the first SZ entries of NEW to jlong with reinterpret_cast into a
// temporary buffer and writes that buffer into the Java long array OLD starting
// at index Z. SZ is evaluated more than once.
#define SETLONGAREGION(OLD,Z,SZ,NEW)                                \
  {                                                                 \
    jlong * temp = new jlong[SZ];                                   \
    for (int i = 0; i < (SZ); i++)                                  \
      temp[i] = reinterpret_cast<jlong>(NEW[i]);                    \
    jenv->SetLongArrayRegion(OLD,Z,(jsize)SZ,temp);                 \
    delete [] temp;                                                 \
  }

#endif

// Deliberately empty error handler. Installing it keeps Z3 from calling exit()
// when an error occurs; the real error handling is performed by the wrapper
// functions, which throw exceptions instead.
void Z3JavaErrorHandler(Z3_context c, Z3_error_code e)
{
  (void)c;
  (void)e;
}

// JNI entry point: installs Z3JavaErrorHandler as the error handler of the
// context whose handle is passed in a0. jenv and cls are unused.
DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_setInternalErrorHandler(JNIEnv * jenv, jclass cls, jlong a0)
{
  Z3_context context = (Z3_context)a0;
  Z3_set_error_handler(context, Z3JavaErrorHandler);
}


#include <assert.h>

// State shared between the propagator JNI entry points and the Z3 callbacks.
// Every member starts out as nullptr.
struct JavaInfo {
  // JNI environment and Java object that the callbacks invoke methods on.
  JNIEnv *jenv{nullptr};
  jobject jobj{nullptr};

  // Method IDs of the Java wrapper methods, one per callback kind.
  jmethodID push{nullptr};
  jmethodID pop{nullptr};
  jmethodID fresh{nullptr};
  jmethodID created{nullptr};
  jmethodID fixed{nullptr};
  jmethodID eq{nullptr};
  jmethodID final{nullptr};
  jmethodID decide{nullptr};

  // Solver callback handle; ScopedCB assigns it for the duration of a callback.
  Z3_solver_callback cb{nullptr};
};

struct ScopedCB {
  JavaInfo *info;
  ScopedCB(JavaInfo *_info, Z3_solver_callback cb): info(_info) {
    info->cb = cb;
  }
  ~ScopedCB() {
    info->cb = nullptr;
  }
};

static void push_eh(void* _p, Z3_solver_callback cb) {
  JavaInfo *info = static_cast<JavaInfo*>(_p);
  ScopedCB scoped(info, cb);
  info->jenv->CallVoidMethod(info->jobj, info->push);
}

static void pop_eh(void* _p, Z3_solver_callback cb, unsigned int number) {
  JavaInfo *info = static_cast<JavaInfo*>(_p);
  ScopedCB scoped(info, cb);
  info->jenv->CallVoidMethod(info->jobj, info->pop, number);
}

static void* fresh_eh(void* _p, Z3_context new_context) {
  JavaInfo *info = static_cast<JavaInfo*>(_p);
  return info->jenv->CallObjectMethod(info->jobj, info->fresh, (jlong)new_context); 
}

static void created_eh(void* _p, Z3_solver_callback cb, Z3_ast _e) {
  JavaInfo *info = static_cast<JavaInfo*>(_p);
  ScopedCB scoped(info, cb);
  info->jenv->CallVoidMethod(info->jobj, info->created, (jlong)_e);
}

static void fixed_eh(void* _p, Z3_solver_callback cb, Z3_ast _var, Z3_ast _value) {
  JavaInfo *info = static_cast<JavaInfo*>(_p);
  ScopedCB scoped(info, cb);
  info->jenv->CallVoidMethod(info->jobj, info->fixed, (jlong)_var, (jlong)_value);
}

static void eq_eh(void* _p, Z3_solver_callback cb, Z3_ast _x, Z3_ast _y) {
  JavaInfo *info = static_cast<JavaInfo*>(_p);
  ScopedCB scoped(info, cb);
  info->jenv->CallVoidMethod(info->jobj, info->eq, (jlong)_x, (jlong)_y);
}

static void final_eh(void* _p, Z3_solver_callback cb) {
  JavaInfo *info = static_cast<JavaInfo*>(_p);
  ScopedCB scoped(info, cb);
  info->jenv->CallVoidMethod(info->jobj, info->final);
}

static void decide_eh(void* _p, Z3_solver_callback cb, Z3_ast _val, unsigned bit, bool is_pos) {
  JavaInfo *info = static_cast<JavaInfo*>(_p);
  ScopedCB scoped(info, cb);
  info->jenv->CallVoidMethod(info->jobj, info->decide, (jlong)_val, bit, is_pos);
}

// JNI entry point: creates the native state for a user propagator.
//
// Allocates a JavaInfo, stores jenv and a global reference to jobj in it,
// resolves the method IDs of the Java wrapper methods on jobj's class, and
// registers push/pop/fresh callbacks on the given solver with the JavaInfo as
// user context.
//
// Returns the JavaInfo pointer as a jlong handle (to be released with
// propagateDestroy). If the global reference or any method ID cannot be
// obtained, nothing is registered, the partially built state is freed, and 0 is
// returned with a Java exception pending (in builds with assertions enabled the
// assert fires first, as before).
DLL_VIS JNIEXPORT jlong JNICALL Java_com_microsoft_z3_Native_propagateInit(JNIEnv *jenv, jclass cls, jobject jobj, jlong ctx, jlong solver) {
  JavaInfo *info = new JavaInfo;

  info->jenv = jenv;
  info->jobj = jenv->NewGlobalRef(jobj);

  jclass jcls = info->jobj ? jenv->GetObjectClass(info->jobj) : nullptr;

  // The && chain stops at the first failed lookup: a failed GetMethodID leaves
  // an exception pending, and further lookups must not be attempted then.
  const bool resolved = jcls
    && (info->push = jenv->GetMethodID(jcls, "pushWrapper", "()V"))
    && (info->pop = jenv->GetMethodID(jcls, "popWrapper", "(I)V"))
    && (info->fresh = jenv->GetMethodID(jcls, "freshWrapper", "(J)Lcom/microsoft/z3/Native$UserPropagatorBase;"))
    && (info->created = jenv->GetMethodID(jcls, "createdWrapper", "(J)V"))
    && (info->fixed = jenv->GetMethodID(jcls, "fixedWrapper", "(JJ)V"))
    && (info->eq = jenv->GetMethodID(jcls, "eqWrapper", "(JJ)V"))
    && (info->final = jenv->GetMethodID(jcls, "finWrapper", "()V"))
    && (info->decide = jenv->GetMethodID(jcls, "decideWrapper", "(JIZ)V"));

  if (!resolved) {
    assert(false);
    // Assertions disabled: do not register callbacks that would later call
    // through null method IDs, and do not leak the global ref or the JavaInfo.
    if (info->jobj)
      jenv->DeleteGlobalRef(info->jobj);
    delete info;
    // Make sure the Java caller sees a failure even if JNI raised nothing.
    if (!jenv->ExceptionCheck()) {
      jclass err = jenv->FindClass("java/lang/IllegalStateException");
      if (err)
        jenv->ThrowNew(err, "propagateInit: could not set up the user propagator callbacks");
    }
    return 0;
  }

  Z3_solver_propagate_init((Z3_context)ctx, (Z3_solver)solver, info, push_eh, pop_eh, fresh_eh);

  return (jlong)info;
}

// JNI entry point: releases the state created by propagateInit.
//
// javainfo is the handle returned by propagateInit; a zero handle is ignored.
// The global reference is deleted through the jenv of the calling thread
// rather than the JNIEnv cached at init time, because a JNIEnv pointer is only
// valid on the thread it was handed to.
// NOTE(review): nothing here unregisters `info` from the solver - confirm the
// solver no longer invokes the callbacks after this call.
DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateDestroy(JNIEnv *jenv, jclass cls, jobject jobj, jlong ctx, jlong solver, jlong javainfo) {
  JavaInfo *info = (JavaInfo*)javainfo;
  if (!info)
    return;
  if (info->jobj)
    jenv->DeleteGlobalRef(info->jobj);
  delete info;
}

// JNI entry point: registers created_eh as the "created" callback of the given
// solver in the given context. jenv, cls and jobj are unused.
DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateRegisterCreated(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver) {
  Z3_context context = (Z3_context)ctx;
  Z3_solver target = (Z3_solver)solver;
  Z3_solver_propagate_created(context, target, created_eh);
}

// JNI entry point: registers final_eh as the "final" callback of the given
// solver in the given context. jenv, cls and jobj are unused.
DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateRegisterFinal(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver) {
  Z3_context context = (Z3_context)ctx;
  Z3_solver target = (Z3_solver)solver;
  Z3_solver_propagate_final(context, target, final_eh);
}

// JNI entry point: registers fixed_eh as the "fixed" callback of the given
// solver in the given context. jenv, cls and jobj are unused.
DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateRegisterFixed(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver) {
  Z3_context context = (Z3_context)ctx;
  Z3_solver target = (Z3_solver)solver;
  Z3_solver_propagate_fixed(context, target, fixed_eh);
}

// JNI entry point: registers eq_eh as the "eq" callback of the given solver in
// the given context. jenv, cls and jobj are unused.
DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateRegisterEq(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver) {
  Z3_context context = (Z3_context)ctx;
  Z3_solver target = (Z3_solver)solver;
  Z3_solver_propagate_eq(context, target, eq_eh);
}

// JNI entry point: registers decide_eh as the "decide" callback of the given
// solver in the given context. jenv, cls and jobj are unused.
DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateRegisterDecide(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver) {
  Z3_context context = (Z3_context)ctx;
  Z3_solver target = (Z3_solver)solver;
  Z3_solver_propagate_decide(context, target, decide_eh);
}

// JNI entry point: forwards a propagation to Z3_solver_propagate_consequence.
//
// javainfo is the handle returned by propagateInit. fixed, eq_lhs and eq_rhs are
// Java long arrays of AST handles (any of them may be null); num_fixed and
// num_eqs are the counts passed on to Z3; conseq is the consequence AST handle.
// Returns the boolean result of the Z3 call.
//
// The call needs the solver callback handle that is only present on the
// JavaInfo while one of the *_eh callbacks is running. If the handle (or the
// JavaInfo itself) is missing, the request is rejected: the assert fires in
// builds with assertions enabled, otherwise JNI_FALSE is returned instead of
// handing a null callback to Z3.
// NOTE(review): num_fixed and num_eqs are declared as `long`, which is not a
// JNI type - confirm they match the Java native declaration on all targets.
DLL_VIS JNIEXPORT jboolean JNICALL Java_com_microsoft_z3_Native_propagateConsequence(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver, jlong javainfo, long num_fixed, jlongArray fixed, long num_eqs, jlongArray eq_lhs, jlongArray eq_rhs, jlong conseq) {
  JavaInfo *info = (JavaInfo*)javainfo;
  if (!info || !info->cb) {
    assert(false);
    return JNI_FALSE;
  }
  GETLONGAELEMS(Z3_ast, fixed, _fixed);
  GETLONGAELEMS(Z3_ast, eq_lhs, _eq_lhs);
  GETLONGAELEMS(Z3_ast, eq_rhs, _eq_rhs);
  bool retval = Z3_solver_propagate_consequence((Z3_context)ctx, info->cb, num_fixed, _fixed, num_eqs, _eq_lhs, _eq_rhs, (Z3_ast)conseq);
  RELEASELONGAELEMS(fixed, _fixed);
  RELEASELONGAELEMS(eq_lhs, _eq_lhs);
  RELEASELONGAELEMS(eq_rhs, _eq_rhs);
  return (jboolean) retval;
}

// JNI entry point: registers the expression e with the user propagator.
// When a solver callback handle is currently stored on the JavaInfo, the
// registration goes through that handle; otherwise it goes through the solver
// handle if one was supplied. Having neither is treated as a programming error.
DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateAdd(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver, jlong javainfo, jlong e) {
  JavaInfo *info = (JavaInfo*)javainfo;
  Z3_solver_callback active = info->cb;

  if (active) {
    Z3_solver_propagate_register_cb((Z3_context)ctx, active, (Z3_ast)e);
    return;
  }
  if (solver) {
    Z3_solver_propagate_register((Z3_context)ctx, (Z3_solver)solver, (Z3_ast)e);
    return;
  }
  assert(false);
}


// JNI entry point: forwards a split request to Z3_solver_next_split.
//
// javainfo is the handle returned by propagateInit; e is the AST handle, idx
// and phase are passed on to Z3 (phase converted to Z3_lbool). Returns the
// result of the Z3 call converted to jboolean.
//
// The call needs the solver callback handle that is only present on the
// JavaInfo while one of the *_eh callbacks is running. If the handle (or the
// JavaInfo itself) is missing, the request is rejected: the assert fires in
// builds with assertions enabled, otherwise JNI_FALSE is returned instead of
// handing a null callback to Z3.
// NOTE(review): idx and phase are declared as `long` / `int`, which are not JNI
// types - confirm they match the Java native declaration on all targets.
DLL_VIS JNIEXPORT jboolean JNICALL Java_com_microsoft_z3_Native_propagateNextSplit(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver, jlong javainfo, jlong e, long idx, int phase) {
  JavaInfo *info = (JavaInfo*)javainfo;
  Z3_solver_callback cb = info ? info->cb : nullptr;
  if (!cb) {
    assert(false);
    return JNI_FALSE;
  }
  return (jboolean) Z3_solver_next_split((Z3_context)ctx, cb, (Z3_ast)e, idx, Z3_lbool(phase));
}
