#!/usr/bin/python3
# Control script for Octoprint status display
# 01.12.2020 by hattedsquirrel.net
# Project homepage: https://hattedsquirrel.net/projects/octoprint-status-display/
# Licensed under The GNU General Public License v3.0

import requests
import websocket
import json
try:
    import thread
except ImportError:
    import _thread as thread
import numbers
import signal
from time import sleep
from luma.core.interface.serial import spi
from luma.core.render import canvas
from luma.lcd.device import st7735
from PIL import ImageFont
import RPi.GPIO as GPIO

# OctoPrint connection settings
host_ws = "127.0.0.1:5000"    # <ip>:<port> of OctoPrint's websocket endpoint
host_rest = "127.0.0.1:80"    # <ip>:<port> of OctoPrint's REST API
user = "octodisplay"          # this user needs the "Status" permission
password = "password"         # password of the user above

# Display: ST7735 panel on SPI port 0 / device 0.
# luma uses BCM GPIO numbering, so the fan and overtemp pins configured below
# must be given as BCM numbers as well.
diface = spi(device=0, port=0, gpio_RST=24, gpio_DC=25)
device = st7735(diface)
# Keep the last frame on screen when the script terminates instead of
# letting luma reset the display.
device.cleanup = None

# Bed cooling fan
# Runs a bed cooling fan after a print has finished. The printer state is not
# evaluated, only the bed temperatures are:
#   * the fan turns ON when OctoPrint's bed target is below offtemp while the
#     bed is still more than 1 degree above offtemp;
#   * it turns OFF again once the bed has cooled to offtemp or below, or when a
#     bed target of offtemp or more is set (e.g. by the next print).
# This turned out to work better than tracking the various printer states.
# To stop the fans once the bed has cooled down, set offtemp slightly above the
# highest room temperature you expect.
fan = dict(
    gpio=20,      # BCM GPIO number of the fan output; -1 disables fan control
    offtemp=31,   # bed temperature threshold used by the rules above
    polarity=1,   # 1 = drive the pin HIGH to run the fan(s); 0 = drive it LOW
    on=0,         # internal state (0 = off); leave at 0
)

# Overtemp emergency shutdown (killswitch)
# A relay on a Raspberry Pi GPIO pin cuts the printer's power if the bed or the
# print head gets too hot -- one additional safety net in case the printer
# firmware hangs or the hardware fails.
# Once triggered, this script does not reset the pin. Reset it manually with any
# other program, or simply restart the Pi.
overtemp = dict(
    gpio=17,      # BCM GPIO number of the relay; -1 disables emergency shutdown
    polarity=1,   # 1 = drive the pin HIGH to kill the printer; 0 = drive it LOW
    maxbed=120,   # bed temperature above which the printer is killed
    maxtool=255,  # tool temperature above which the printer is killed
)

# Printer state -> text color. Only the first word of OctoPrint's state text is
# matched against these lists; states found in neither are drawn white.
state_green = 'Starting Printing Sending Resuming Finishing Operational'.split()
state_red = 'Cancelling Error'.split()

# Font selection
# font1: job filename, font2: printer state line.
font1 = ImageFont.truetype("DejaVuSans-Bold.ttf", 16, encoding="unic")
font2 = ImageFont.truetype("DejaVuSans.ttf", 14, encoding="unic")
# fontM1: large remaining-time readout, fontM2: the smaller numeric rows.
# Other fonts that can be substituted:
#   fontM1 -> "DejaVuSansMono-Bold.ttf" @ 24 or "DejaVuSans-Bold.ttf" @ 32
#   fontM2 -> "DejaVuSans.ttf" @ 18
fontM1 = ImageFont.truetype("Ubuntu-B.ttf", 34, encoding="unic")
fontM2 = ImageFont.truetype("Ubuntu-R.ttf", 18, encoding="unic")

# ----------------------------
# End of configuration section
# ----------------------------

# Live status shown on the display. on_message() fills it in from OctoPrint's
# push messages; None marks a value that has not been received.
displaydata = dict(
    currentZ=None,
    filename=None,
    filenamelast=-1,       # cache key: filename the wrapped text belongs to (-1 forces the first wrap)
    filenamebroken="-",    # cache: filename with its line break inserted
    filenamelines=0,       # cache: number of lines in filenamebroken
    completion=None,
    filesize=None,
    filepos=None,
    printTimeTotal=None,
    printTimePast=None,
    printTimeLeft=None,
    state="unknown",
    tempBedActual=None,
    tempBedTarget=None,
    tempToolActual=None,
    tempToolTarget=None,
)


kill_now = False   # set to stop the main reconnect loop
ws = None          # the WebSocketApp, created in __main__

class Hasher(dict):
    """dict whose missing keys spring into existence as empty Hashers.

    Reading ``h['a']['b']`` never raises KeyError for keys that are absent:
    each missing key is stored with a fresh, empty instance of the same class
    and that instance is returned. Values that are already stored (e.g. plain
    dicts passed to the constructor) are left as they are.
    Based on https://stackoverflow.com/a/3405143/190597
    """

    def __missing__(self, key):
        child = type(self)()
        self[key] = child
        return child

# Format a number for the display; unknown values become dashes.
def format_number(val,format):
    """Render *val* with the %-style *format* string.

    Non-numeric values (e.g. None) are rendered as if they were 0 and every
    "0" is then turned into "-", so the placeholder keeps the width of a real
    value. Each space is doubled because the Ubuntu display font has no
    "figure space" (a digit-wide blank); two normal spaces stand in for it.
    """
    known = isinstance(val, numbers.Number)
    rendered = format % (val if known else 0)
    if not known:
        rendered = rendered.replace("0", "-")
    return rendered.replace(" ", "  ")

# Format a duration as H:MM:SS; unknown values become " -:--:--".
def format_time(seconds):
    """Render *seconds* as ``%2d:%02d:%02d`` (hours, minutes, seconds).

    Non-numeric input (e.g. None) yields the placeholder " -:--:--". As in
    format_number(), spaces are doubled to approximate a digit-wide blank in
    the Ubuntu display font.
    """
    if not isinstance(seconds, numbers.Number):
        text = ' -:--:--'
    else:
        hours = seconds / 3600
        minutes = seconds / 60 % 60
        secs = seconds % 60
        text = '%2d:%02d:%02d' % (hours, minutes, secs)
    return text.replace(" ", "  ")

# Break line if longer than display width
def break_text(draw,position,text,font,maxwidth):
    """Insert one line break so that the first line fits into *maxwidth* pixels.

    The first line gets the longest prefix of *text* whose rendered width in
    *font* does not exceed *maxwidth*; the remainder goes onto a second line,
    which is not wrapped further.

    draw     -- ImageDraw-like object used to measure text
    position -- unused; kept for call compatibility
    text     -- the string to wrap
    font     -- font used for measuring
    maxwidth -- available width in pixels

    Returns (text_with_break, 2) if a break was inserted, otherwise (text, 1).
    """
    def measure(fragment):
        # ImageDraw.textsize() was removed in Pillow 10; fall back to
        # textlength() there, but keep textsize() where it exists so the
        # wrapping is unchanged on older Pillow versions.
        if hasattr(draw, "textsize"):
            return draw.textsize(fragment, font=font)[0]
        return draw.textlength(fragment, font=font)

    # Measure prefixes up to and including the whole string; stopping at
    # len(text) - 1 would miss text that overflows only on its last character.
    for pos in range(1, len(text) + 1):
        if measure(text[:pos]) > maxwidth:
            # text[:pos - 1] is the longest prefix that still fits
            return (text[:pos - 1] + "\n" + text[pos - 1:], 2)
    return (text, 1)

# Output data on display
def update_display(displaydata):
    """Draw the complete status screen for *displaydata* on the global device.

    Layout from top to bottom: job filename (wrapped by break_text), printer
    state, remaining time (large, centered), Z height and completion, elapsed
    and total print time, bed and extruder temperatures. The filename and the
    remaining time are colored by the first word of the state string
    (state_green -> green, state_red -> red, other strings -> white).

    displaydata -- the status dict maintained by on_message(). Besides being
                   read, its 'filenamelast', 'filenamebroken' and
                   'filenamelines' entries are updated to cache the wrapped
                   filename, so it is only re-measured when the name changes.
    """

    def text_width(draw, text, font):
        # ImageDraw.textsize() was removed in Pillow 10; use textlength()
        # there, but keep textsize() where it exists so the layout is unchanged.
        if hasattr(draw, "textsize"):
            return draw.textsize(text, font=font)[0]
        return draw.textlength(text, font=font)

    with canvas(device, dither=False) as draw:
        dwidth = device.width
        y = 0

        # Colors
        state = displaydata['state']
        fill = "#00ff00"
        if isinstance(state, str):
            first_word = state.partition(' ')[0]
            if first_word in state_green:
                fill = "#00ff00"
            elif first_word in state_red:
                fill = "#ff0000"
            else:
                fill = "#ffffff"
        else:
            # draw.text() expects a string; show a placeholder instead
            state = "-"

        # Filename, possibly broken over two lines. The line break is only
        # recalculated when the name actually changed, to save CPU cycles.
        if displaydata['filename'] != displaydata['filenamelast']:
            text = displaydata['filename'] if isinstance(displaydata['filename'], str) else "-"
            text, lines = break_text(draw, (0, y), text, font1, dwidth)
            displaydata['filenamelast'] = displaydata['filename']
            displaydata['filenamebroken'] = text
            displaydata['filenamelines'] = lines
        else:
            text = displaydata['filenamebroken']
            lines = displaydata['filenamelines']
        draw.multiline_text((0, y), text, fill=fill, font=font1, spacing=0)
        y += lines * 16

        # Printer status, drawn without anti-aliasing (fontmode "1")
        draw.fontmode = "1"
        draw.text((0, y), state, fill="#ff2020", font=font2)
        draw.fontmode = "L"
        y += 12
        if lines <= 1:
            y += 8  # use the space a second filename line would have taken

        # Remaining time, centered
        text = '-' + format_time(displaydata['printTimeLeft'])
        width = text_width(draw, text, fontM1)
        draw.text(((dwidth - width) / 2, y), text, fill=fill, font=fontM1)
        y += 32
        if lines <= 1:
            y += 8

        # Progress: Z height left-aligned, completion right-aligned
        draw.text((0, y), format_number(displaydata['currentZ'], "%4.2f") + " mm", fill="white", font=fontM2)

        text = format_number(displaydata['completion'], "%d") + " %"
        width = text_width(draw, text, fontM2)
        draw.text((dwidth - width, y), text, fill="white", font=fontM2)
        y += 16

        # Elapsed and total time, centered (left-aligned when wider than the
        # display). From 10 h on the total has no leading padding, so the
        # separator carries a trailing space instead of a leading one.
        total = displaydata['printTimeTotal']
        if isinstance(total, numbers.Number) and total >= 60 * 60 * 10:
            text = " of "
        else:
            text = "  of"
        text = format_time(displaydata['printTimePast']) + text + format_time(total)
        width = text_width(draw, text, fontM2)
        width = 0 if width > dwidth else (dwidth - width) / 2
        draw.text((width, y), text, fill="white", font=fontM2)
        y += 16

        # Bed (B) and extruder (E) temperatures as actual|target
        text = "B:" + format_number(displaydata['tempBedActual'], "%3d") + "|" + format_number(displaydata['tempBedTarget'], "%3d")
        draw.text((0, y), text, fill="white", font=fontM2)

        text = "E:" + format_number(displaydata['tempToolActual'], "%3d") + "|" + format_number(displaydata['tempToolTarget'], "%3d")
        width = text_width(draw, text, fontM2)
        draw.text((dwidth - width, y), text, fill="white", font=fontM2)

# Fan handling
def setfan(on):
    """Switch the bed cooling fan on (truthy *on*) or off and record the state.

    The pin level honours fan['polarity']: with polarity 1 the pin is driven
    HIGH to run the fan, with polarity 0 it is driven LOW. When fan['gpio'] is
    negative, only the internal state fan['on'] (1 or 0) is updated.
    """
    global fan
    on = bool(on)
    print("Fan on" if on else "Fan off")
    if fan['gpio'] >= 0:
        # Compare as booleans so any truthy value means "on". A plain
        # on == polarity would drive the "off" level for e.g. on=2 while
        # still recording fan['on'] = 1.
        GPIO.output(fan['gpio'], GPIO.HIGH if on == bool(fan['polarity']) else GPIO.LOW)
    fan['on'] = 1 if on else 0

# Killswitch handling
def overtemp_shutdown():
    """Drive the killswitch pin to its "kill" level and log the temperatures.

    The kill level is HIGH for overtemp['polarity'] == 1 and LOW otherwise.
    A negative overtemp['gpio'] skips the pin and only prints the messages.
    """
    pin = overtemp['gpio']
    if pin >= 0:
        kill_level = GPIO.HIGH if overtemp['polarity'] else GPIO.LOW
        GPIO.output(pin, kill_level)
    print('Emergency shutdown activated!')
    print('Bed temp: %s' % format_number(displaydata['tempBedActual'], "%3d"))
    print('Tool temp: %s' % format_number(displaydata['tempToolActual'], "%3d"))

# Parse new message from octoprint
def on_message(ws, message):
    """Handle one push message from OctoPrint's websocket.

    Only messages containing a "current" payload are used. Its job, progress,
    state and temperature fields are copied into the global displaydata (a
    missing field becomes None, a missing state text becomes "unknown").
    Then the overtemp killswitch is checked, the display is redrawn and the
    bed fan is switched.

    ws      -- the WebSocketApp that received the message (unused)
    message -- raw JSON text of the message
    """
    global displaydata
    jdata = json.loads(message)
    if not isinstance(jdata, dict) or 'current' not in jdata:
        return
    current = jdata['current']

    def lookup(node, *path):
        # Follow *path* through nested dicts; any missing level yields None
        # instead of raising KeyError.
        for key in path:
            if not isinstance(node, dict) or key not in node:
                return None
            node = node[key]
        return node

    def is_number(value):
        return isinstance(value, numbers.Number)

    # copy data over
    displaydata['currentZ'] = lookup(current, 'currentZ')
    filename = lookup(current, 'job', 'file', 'name')
    if isinstance(filename, str):
        filename = filename.rsplit('.', 1)[0]  # strip the file extension
    displaydata['filename'] = filename
    displaydata['completion'] = lookup(current, 'progress', 'completion')
    displaydata['filesize'] = lookup(current, 'job', 'file', 'size')
    displaydata['filepos'] = lookup(current, 'progress', 'filepos')
    displaydata['printTimeTotal'] = lookup(current, 'job', 'estimatedPrintTime')
    displaydata['printTimePast'] = lookup(current, 'progress', 'printTime')
    displaydata['printTimeLeft'] = lookup(current, 'progress', 'printTimeLeft')
    state = lookup(current, 'state', 'text')
    displaydata['state'] = state if isinstance(state, str) else "unknown"

    temps = lookup(current, 'temps')
    if isinstance(temps, list) and temps:
        # NOTE(review): only the first sample of the list is used; if the
        # list can hold several samples, confirm which one is the newest.
        sample = temps[0]
        displaydata['tempBedActual'] = lookup(sample, 'bed', 'actual')
        displaydata['tempBedTarget'] = lookup(sample, 'bed', 'target')
        displaydata['tempToolActual'] = lookup(sample, 'tool0', 'actual')
        displaydata['tempToolTarget'] = lookup(sample, 'tool0', 'target')

    bed_actual = displaydata['tempBedActual']
    bed_target = displaydata['tempBedTarget']
    tool_actual = displaydata['tempToolActual']

    # emergency overtemp shutdown. Checked before redrawing so that an error
    # while rendering cannot skip it; unknown temperatures are not compared.
    if ((is_number(bed_actual) and bed_actual > overtemp['maxbed'])
            or (is_number(tool_actual) and tool_actual > overtemp['maxtool'])):
        overtemp_shutdown()

    # display
    update_display(displaydata)

    # fan handling: on above offtemp + 1, off at or below offtemp (hysteresis).
    # Skipped until both bed temperatures are known.
    if is_number(bed_actual) and is_number(bed_target):
        if fan['on'] and (bed_actual <= fan['offtemp'] or bed_target >= fan['offtemp']):
            setfan(0)
        elif not fan['on'] and bed_target < fan['offtemp'] and bed_actual > fan['offtemp'] + 1:
            setfan(1)

# Websocket callbacks
def on_error(ws, error):
    """Print a websocket error; reconnecting is left to the main loop."""
    banner = "### websocket error ###"
    print(banner)
    print(error)

def on_close(ws, *args):
    """Log that the websocket connection was closed.

    websocket-client >= 1.0 calls on_close(ws, close_status_code, close_msg),
    while older releases may call on_close(ws); *args absorbs the extra
    arguments so both work. Reconnecting is handled by the loop in __main__.
    """
    print("### websocket closed ###")

def on_open(ws):
    """Authenticate a freshly opened websocket against OctoPrint.

    Logs in via the REST API to obtain a session key, then sends the "auth"
    and "throttle" messages from a helper thread. If login yields no session,
    the socket is closed and kill_now is set so the main loop stops
    reconnecting.
    """
    global kill_now
    print("### websocket opened ###")

    session = login()
    if not session:
        ws.close()
        kill_now = True
        return

    # json.dumps escapes quotes/backslashes in the user name or session key,
    # which plain string concatenation did not.
    auth_message = json.dumps({"auth": user + ':' + session})

    def run(*args):
        ws.send(auth_message)
        ws.send('{"throttle": 2}')  # lower OctoPrint's push rate to conserve compute power
    thread.start_new_thread(run, ())

def login(timeout=10):
    """Log in to the OctoPrint REST API and return the session key.

    Retries every 2 seconds while the API is unreachable, until it answers
    or a shutdown is requested via kill_now.

    timeout -- seconds to wait for each HTTP request before it counts as
               failed and is retried

    Returns the session key, or "" if the login was refused, the reply held
    no usable session, or a shutdown was requested before the API answered.
    """
    print("Logging in as user " + user)
    response = None
    while response is None and not kill_now:
        try:
            response = requests.post('http://' + host_rest + '/api/login',
                                     json={'user': user, 'pass': password},
                                     timeout=timeout)
        except requests.exceptions.RequestException:
            print("Connection to REST API failed. Retrying...")
            sleep(2)

    if response is None:
        # Shutdown was requested before the API answered
        return ""

    print("Status code: ", response.status_code)
    if not response.ok:
        return ""
    try:
        payload = response.json()
    except ValueError:  # body is not JSON
        return ""
    session = payload.get('session') if isinstance(payload, dict) else None
    return session if isinstance(session, str) else ""


# React to Ctrl+C (SIGINT) and SIGTERM
def exit_gracefully(self,signum):
    """Signal handler: request shutdown and close the websocket.

    Installed via signal.signal(), so it is called as handler(signum, frame);
    the parameter names are historical: *self* receives the signal number
    and *signum* the current stack frame. Neither is used.
    """
    global ws, kill_now
    # Set the flag first so loops polling kill_now stop even if closing fails
    kill_now = True
    # ws is still None if the signal arrives before the WebSocketApp exists,
    # e.g. during the initial login() in __main__
    if ws is not None:
        ws.close()

# Main
if __name__ == "__main__":
    # Ctrl+C (SIGINT) and SIGTERM stop the reconnect loop below
    signal.signal(signal.SIGINT, exit_gracefully)
    signal.signal(signal.SIGTERM, exit_gracefully)

    if fan['gpio'] >= 0:
        # Start with the fan switched off so the pin matches fan['on'] == 0.
        # Without an initial level the fan could already be running (e.g. a
        # polarity-0 fan on a pin that starts LOW) and would only stop after
        # the next on/off cycle.
        GPIO.setup(fan['gpio'], GPIO.OUT,
                   initial=GPIO.LOW if fan['polarity'] else GPIO.HIGH)
    if overtemp['gpio'] >= 0:
        # NOTE(review): no initial level is passed here, presumably so that a
        # killswitch triggered before a script restart is not released. With
        # polarity 0 (kill = LOW), confirm the pin does not start out LOW.
        GPIO.setup(overtemp['gpio'], GPIO.OUT)

    # Block until the REST API answers (or a shutdown is requested). The
    # session is not used here; on_open() logs in again for websocket auth.
    login()

    websocket.enableTrace(True)
    ws = websocket.WebSocketApp("ws://" + host_ws + "/sockjs/websocket",
                                on_message=on_message,
                                on_error=on_error,
                                on_close=on_close)
    ws.on_open = on_open
    # Show the placeholder screen before the first message arrives
    update_display(displaydata)
    # run_forever() returns when the connection ends; reconnect after 2 s
    # until exit_gracefully() or a failed login sets kill_now.
    while not kill_now:
        ws.run_forever()
        if not kill_now:
            sleep(2)  # Retry

