#!/usr/bin/env python
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Very bare bones shell for driving a Livy session. Usage:
#
#   livy-shell url [option=value ...]
#
# Options are set directly in the session creation request, so they must match the names of fields
# in the CreateInteractiveRequest structure. Option values should be python-like objects (should be
# parseable by python's "eval" function; naked strings are allowed). For example:
#
#   kind=spark
#   jars=[ '/foo.jar', '/bar.jar' ]
#   conf={ foo : bar, 'spark.option' : opt_value }
#
# By default, a Spark (Scala) session is created.
#

import json
import readline
import signal
import sys
import time
import urlparse


class ControlCInterrupt(Exception):
  """Signals that the user pressed Ctrl-C (SIGINT) while the shell was running."""


def check(condition, msg, *args):
  """Abort the program with exit status 1 unless `condition` is truthy.

  When extra positional arguments are supplied, `msg` is treated as a
  %-style format string and interpolated with them. The resulting message is
  written to stderr, followed by a newline, before sys.exit(1) is called.
  Returns None (and writes nothing) when the condition holds.
  """
  if condition:
    return
  if args:
    msg = msg % args
  # Write to the stream directly instead of using the Python 2 only
  # "print >> stream" statement, which Python 3 parses as an expression and
  # rejects at runtime with a TypeError before the message is ever shown.
  sys.stderr.write("%s\n" % (msg,))
  sys.exit(1)


def message(msg, *args):
  """Print `msg` on stdout, %-formatting it with `args` when any are given.

  Without extra arguments the text is printed verbatim, so a literal '%' in
  `msg` is safe in that case.
  """
  text = msg % args if args else msg
  print(text)


# 'requests' is a third-party module: when it cannot be imported, print a
# readable explanation and exit with status 1 instead of letting the
# ImportError traceback reach the user.
try:
  import requests
except ImportError:
  message("Unable to import 'requests' module, which is required by livy-shell.")
  sys.exit(1)


class LiteralDict(dict):
  """Dictionary whose lookups of absent keys yield the key itself.

  Passed to eval() as the namespace for option values, so that an unquoted
  word evaluates to a string of the same name instead of raising NameError.
  """

  def __missing__(self, name):
    # Called by dict's item lookup for keys that are not present. The key is
    # handed back as the value and is not stored in the dictionary.
    return name


def request(method, uri, body):
  """Send one HTTP request to the Livy server and return the decoded reply.

  `method` is upper-cased before use and `uri` is resolved against the
  module-level `url`. A truthy `body` is sent as the JSON payload; a falsy one
  (None, empty dict, ...) sends no payload at all.

  Whatever resp.raise_for_status() raises is propagated. Otherwise the parsed
  JSON body is returned when the status code is below "multiple choices" and
  is not "no content"; in every other case the result is None.
  """
  headers = { 'Content-Type' : 'application/json', 'X-Requested-By': 'livy' }
  options = { 'headers': headers }
  if body:
    options['json'] = body

  verb = method.upper()
  target = urlparse.urljoin(url.geturl(), uri)
  resp = requests.request(verb, target, **options)
  resp.raise_for_status()

  code = resp.status_code
  if code >= requests.codes.multiple_choices or code == requests.codes.no_content:
    return None
  return resp.json()


def get(uri):
  """Issue a body-less GET for `uri` and return what request() returns."""
  return request('GET', uri, body=None)


def post(uri, body):
  """POST `body` to `uri` and return what request() returns."""
  return request('POST', uri, body=body)


def delete(uri):
  """Issue a body-less DELETE for `uri` and return what request() returns."""
  return request('DELETE', uri, body=None)


def create_session():
  """Build the session creation payload from the command line and POST it.

  The payload starts with a default "kind" of "spark". Every argument after the
  server URL (sys.argv[2:]) must have the form key=value with a non-empty key;
  anything else ends the program through check(). Each value is evaluated as a
  Python expression in a LiteralDict namespace, so unquoted words become
  strings. Later options overwrite earlier ones, including the default kind.

  Returns whatever post() returns for "/sessions".
  """
  payload = { "kind" : "spark" }
  for opt in sys.argv[2:]:
    key, sep, raw_value = opt.partition('=')
    # A missing '=' leaves sep empty; a leading '=' leaves key empty.
    check(sep == '=' and key != '', "Invalid option: %s.", opt)
    # NOTE(review): eval() executes arbitrary expressions taken from the
    # command line. That is by design for this tool, but it must never be fed
    # input that does not come from the person running the shell.
    payload[key] = eval(raw_value, LiteralDict())

  return post("/sessions", payload)


def wait_for_idle(sid):
  """Block until session `sid` is no longer in the 'starting' state.

  The session is fetched once and then re-fetched every 5 seconds for as long
  as its state is 'starting', printing a progress line before each wait.

  Raises:
    Exception: if the state reached after the wait is anything but 'idle'.
  """
  session_uri = "/sessions/%d" % (sid, )
  session = get(session_uri)
  while session['state'] == 'starting':
    message("Session not ready yet (%s)", session['state'])
    time.sleep(5)
    session = get(session_uri)

  if session['state'] != 'idle':
    # Call-style raise: the "raise Exception, msg" statement form is only
    # accepted by Python 2 and is a syntax error on Python 3.
    raise Exception("Session failed to start.")


def monitor_statement(sid, s):
  """Poll statement `s` of session `sid` until it finishes, then print the outcome.

  `s` is the statement dictionary returned by the server. While its state is
  neither 'available' nor 'error' the statement is re-fetched once per second,
  and a "(waiting for result...)" notice is printed periodically.

  On 'available' the output is reported according to its status:
    * 'ok'    - the 'text/plain' result if it contains non-whitespace text, or
                a generic success line when there is no text representation;
    * 'error' - the error name and value, followed by the traceback if any;
    * other   - a line naming the unknown status.
  On 'error' the statement's 'error' field is printed.
  """
  polls_since_notice = 0
  while True:
    state = s['state']

    if state == 'error':
      message("%s", s['error'])
      return

    if state == 'available':
      output = s['output']
      status = output['status']
      if status == 'ok':
        text = output['data'].get('text/plain', None)
        if text is None:
          message("Success (non-text result).")
        elif text.rstrip():
          # Blank or whitespace-only results are deliberately not echoed.
          message("%s", text)
      elif status == 'error':
        ename = output['ename']
        evalue = output['evalue']
        trace_text = "\n".join(output.get('traceback', []))
        message("%s: %s", ename, evalue)
        if trace_text:
          message("%s", trace_text)
      else:
        message("Statement finished with unknown status '%s'.", status)
      return

    # Still running: emit the notice once the counter reaches 10, then start
    # counting again from zero.
    if polls_since_notice == 10:
      message("(waiting for result...)")
      polls_since_notice = 0
    else:
      polls_since_notice += 1
    time.sleep(1)
    s = get("/sessions/%d/statements/%s" % (sid, s['id']))


def run_shell(sid, session_kind):
  """Run the interactive prompt for session `sid` until the user types quit().

  Each line read from the prompt is posted as a statement to the session and
  its result is reported through monitor_statement(). A SIGINT handler is
  installed that prints a hint and raises ControlCInterrupt; when that happens
  while reading a line, the prompt is simply shown again.

  EOFError from the prompt read is not handled here and reaches the caller.
  """
  prompt = "{} ({}) > ".format(session_kind, sid)

  def ctrl_c_handler(signum, frame):
    message("\nPlease type quit() to exit the livy shell.")
    raise ControlCInterrupt()

  signal.signal(signal.SIGINT, ctrl_c_handler)

  while True:
    try:
      cmd = raw_input(prompt)
      if cmd == "quit()":
        return
    except ControlCInterrupt:
      # Ctrl-C at the prompt discards the current line and asks again.
      continue

    # NOTE(review): the try block above only covers the prompt read, so a
    # Ctrl-C that arrives while the statement is posted or monitored is not
    # caught here and propagates out of this function.
    statement = post("/sessions/%d/statements" % (sid, ), { 'code' : cmd })
    monitor_statement(sid, statement)


#
# main()
#

check(len(sys.argv) > 1, "Missing arguments.")

# Parsed base URL of the Livy server; request() resolves every URI against it.
url = urlparse.urlparse(sys.argv[1])
# -1 means "no session was created", which lets the cleanup below know whether
# there is anything to delete.
sid = -1

try:
  message("Creating new session...")
  session = create_session()
  sid = int(session['id'])
  # Resolve the kind once, with the same fallback for the status line and the
  # prompt. Indexing session['kind'] directly for the status line would raise
  # KeyError on a reply without 'kind', before the fallback could ever apply.
  kind = session.get('kind', 'spark')
  message("New session (id = %d, kind = %s), waiting for idle state...", sid, kind)
  wait_for_idle(sid)
  message("Session ready.")
  run_shell(sid, kind)
except EOFError:
  # End of input at the prompt (e.g. Ctrl-D) ends the shell quietly.
  pass
finally:
  # Delete the session on every exit path once it has been created, including
  # when an exception is propagating.
  if sid != -1:
    delete("/sessions/%d" % (sid, ))
