import serial
import threading
import time

class Frame:
    """One timestamped sensor frame.

    Attributes:
        time: timestamp supplied by the caller (GridEye passes time.time()).
        data: the frame payload, stored as given.
    """

    def __init__(self, time, data):
        self.time, self.data = time, data

class GridEye():
    """Background reader for one serial board that streams text records.

    A daemon thread reads newline-terminated records of the form
    ``<tag>:<payload>`` and keeps only the latest value of each kind:

    * tag containing ``Distance`` -> ``self.distance`` (float, capped at 200.0)
    * tag containing ``105`` with 64 hex values -> ``self.frame1``
    * any other tag with 64 hex values -> ``self.frame2``

    Hex values are multiplied by 0.25 and split into 8 rows of 8.
    ``frame1``/``frame2`` hold ``Frame`` objects (or None). Consumers should
    hold ``self.lock`` while reading or clearing these attributes.
    """

    def __init__(self, serialPort, baudrate):
        self.port = serial.Serial(serialPort, baudrate)
        self.frame1 = None
        self.frame2 = None
        self.reading = True
        self.distance = -1
        self.lock = threading.Lock()
        # Daemon so the blocking port read never keeps the process alive.
        # (Thread.setDaemon is deprecated since Python 3.10.)
        self.thread = threading.Thread(target=self.reader, daemon=True)

    def start(self):
        """Discard stale buffered input, then start the reader thread."""
        self.port.reset_input_buffer()
        self.thread.start()

    def stop(self):
        """Ask the reader thread to exit and wait for it.

        The reader only checks ``self.reading`` between ``port.read()``
        calls, so this returns after the next byte (or read timeout) arrives.
        """
        self.reading = False
        self.thread.join()

    def reader(self):
        """Thread body: read and dispatch records until ``stop()`` is called."""
        while self.reading:
            line = self._read_line()
            if b':' not in line:
                continue
            try:
                self._handle_line(line)
            except Exception as e:
                # Best-effort: report a malformed record and keep reading.
                print(e)

    def _read_line(self):
        """Return bytes up to (not including) the next b'\\n', or what was
        collected when ``self.reading`` became False."""
        buf = bytearray()
        while self.reading:
            c = self.port.read()
            if c == b'\n':
                break
            buf += c
        return bytes(buf)

    def _handle_line(self, line):
        """Parse one ``tag:payload`` record and store the result under the lock."""
        parts = line.decode('utf-8').split(':')
        tag, payload = parts[0], parts[1]

        if 'Distance' in tag:
            dist = min(float(payload), 200.0)
            with self.lock:
                self.distance = dist
            return

        values = [int(x, 16)*0.25 for x in payload.split()]
        if len(values) != 64:
            # Report the size actually received (previously this referenced
            # `data`, which is unbound on this path).
            print('something wrong', len(values))
            return

        data = [values[i:i+8] for i in range(0, 64, 8)]
        with self.lock:
            if '105' in tag:
                self.frame1 = Frame(time.time(), data)
            else:
                self.frame2 = Frame(time.time(), data)
              
        
if __name__ == '__main__':
    # Live demo: two serial boards each deliver two 8x8 frames. The four
    # frames are upscaled, a per-sensor background learned from the first
    # AVERAGE_FRAME sets is subtracted, and the tiles are shown in a window
    # and written to output.avi. Press 'c' to snapshot, 'q' to quit.
    import cv2
    import numpy as np
    import math
    import json

    def exponential(img, value):
        """Apply the power curve ``x ** value``, rescaled so 255 maps to 255."""
        tmp = cv2.pow(img.astype(np.double), value)*(255.0/(255.0**value))
        return tmp.astype(np.uint8)

    def to_uint8(arr):
        """Clamp *arr* to 0..255 and cast to uint8.

        A bare ``astype(np.uint8)`` does not clamp: out-of-range values can
        wrap around, turning cold pixels bright and very hot pixels dark.
        """
        return np.clip(arr, 0, 255).astype(np.uint8)

    def log_frames(path, frames):
        """Append the first frame's timestamp, then each frame's data, one JSON value per line."""
        with open(path, 'a') as f:
            f.write(json.dumps(frames[0].time)+'\n')
            for frame in frames:
                f.write(json.dumps(frame.data)+'\n')

    SIZE = 128
    AVERAGE_FRAME = 10
    distanceBetweenSensors_w = 2.6 #cm
    distanceBetweenSensors_h = 2.6 #cm
    distance2Object = 60.0 #cm
    ADJUST_BACK = 5  # added to the learned background before subtraction
    EXPONENTAL_VALUE = 0.4
    PRODUCTION_THRESHOLD = 100
    MIN_EXIST_TIME = 0.5
    W_ARRAY = np.array([list(range(SIZE*2)) for x in range(SIZE*2)])
    H_ARRAY = np.array([[x]*(SIZE*2) for x in range(SIZE*2)])

    grideye = GridEye('COM18', 115200)
    grideye.start()
    grideye2 = GridEye('COM24', 115200)
    grideye2.start()

    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    videoWriter = cv2.VideoWriter('output.avi', fourcc, 10.0, (SIZE*4,SIZE*4))
    siftVideoWriter = cv2.VideoWriter('sift.avi', fourcc, 10.0, (SIZE*2,SIZE*1))
    cv2.imshow('sample', np.zeros((SIZE*3,SIZE*2), np.uint8))
    cnt = 0
    # One background accumulator per sensor image, allocated once. (They used
    # to be appended on every frame, growing without bound.)
    avers = [np.zeros((SIZE,SIZE), np.uint16) for _ in range(4)]
    hasPos = False
    endTime = 0
    startTime = 0
    out = None     # last rendered image; None until the background is learned
    frames = None  # frames that produced `out`
    while True:
        if grideye.frame1 and grideye.frame2 and grideye2.frame1 and grideye2.frame2:
            with grideye.lock, grideye2.lock:
                frames = [grideye.frame1, grideye.frame2, grideye2.frame1, grideye2.frame2]
                grideye.frame1 = None
                grideye.frame2 = None
                grideye2.frame1 = None
                grideye2.frame2 = None
                distance2Object = grideye.distance + grideye2.distance + 1
                print (distance2Object)
                if distance2Object <= 0:
                    distance2Object = 200
            log_frames('log.txt', frames)

            imgs = []
            for frame in frames:
                img = (np.array(frame.data)-15)*10
                img = cv2.resize(to_uint8(img), (SIZE,SIZE), interpolation = cv2.INTER_LINEAR) # INTER_LINEAR, INTER_CUBIC
                imgs.append(img)

            # Learn the background from the first AVERAGE_FRAME frame sets.
            if cnt < AVERAGE_FRAME:
                cnt += 1
                for aver, img in zip(avers, imgs):
                    aver += img  # in-place on the uint16 accumulator
                if cnt == AVERAGE_FRAME:
                    avers = [to_uint8(aver/AVERAGE_FRAME + ADJUST_BACK) for aver in avers]
                continue

            # cv2.subtract saturates at 0, so pixels at or below the background vanish.
            for i in range(len(imgs)):
                imgs[i] = cv2.subtract(imgs[i], avers[i])
            print ('xdd')

            out = np.full((SIZE*4, SIZE*4), 255, dtype=np.uint16)
            out[:SIZE, :SIZE] = imgs[0]
            out[:SIZE, SIZE:SIZE*2] = imgs[1]
            out[SIZE:SIZE*2, :SIZE] = imgs[2]
            out[SIZE:SIZE*2, SIZE:SIZE*2] = imgs[3]
            # Disabled stitching / position experiment, kept for reference. It is
            # a bare string literal, so none of it runs; the `if False and ...`
            # branch below depends on names (merge, maxProduct, offset_w/h) that
            # it would define.
            '''
            try:
                overlap_w = int(SIZE - (distanceBetweenSensors_w / (2*distance2Object*math.tan(30.0/180.0*math.pi))) * SIZE)
            except:
                overlap_w = 0
            if overlap_w < 0:
                overlap_w = 0
                
            try:
                overlap_h = int(SIZE - (distanceBetweenSensors_h / (2*distance2Object*math.tan(30.0/180.0*math.pi))) * SIZE)
            except:
                overlap_h = 0
            if overlap_h < 0:
                overlap_h = 0
            
            tmp = np.zeros((SIZE, SIZE*2-overlap_w), dtype=np.uint16)
            tmp[:, :SIZE] = imgs[0]
            tmp[:, -SIZE:] += imgs[1]
            tmp[:, (SIZE-overlap_w): SIZE] = tmp[:, (SIZE-overlap_w): SIZE]/2
            
            tmp2 = np.zeros((SIZE, SIZE*2-overlap_w), dtype=np.uint16)
            tmp2[:, :SIZE] = imgs[2]
            tmp2[:, -SIZE:] += imgs[3]
            tmp2[:, (SIZE-overlap_w): SIZE] = tmp2[:, (SIZE-overlap_w): SIZE]/2
            
            merge = np.zeros((SIZE*2-overlap_h, SIZE*2-overlap_w), dtype=np.uint16)
            merge[:SIZE, :] = tmp
            merge[-SIZE:, :] += tmp2
            merge[(SIZE-overlap_h):SIZE, :] = merge[(SIZE-overlap_h):SIZE, :]/2
            # merge = exponential(merge, EXPONENTAL_VALUE)
            
            
            offset_w = int(overlap_w/2)
            offset_h = int(overlap_h/2)
            print (SIZE*2+offset_h, SIZE*4-overlap_h+offset_h, offset_w, SIZE*2-overlap_w+offset_w)
            out[SIZE*2+offset_h:SIZE*4-overlap_h+offset_h, offset_w: SIZE*2-overlap_w+offset_w] = merge
            
            
            maxProduct = 0
            overlap_w = 0
            for i in range(80, 128):
                product = sum(imgs[0][:,SIZE-i:].astype(np.uint32)*imgs[1][:,:i].astype(np.uint32))
                product += sum(imgs[2][:,SIZE-i:].astype(np.uint32)*imgs[3][:,:i].astype(np.uint32))
                product = sum(product) / len(product)
                if product > maxProduct:
                    maxProduct = product
                    overlap_w = i
                    
            tmp = maxProduct
            maxProduct = 0
            overlap_h = 0
            for i in range(80, 128):
                product = sum(imgs[0][SIZE-i:, :].astype(np.uint32)*imgs[2][:i,:].astype(np.uint32))
                product += sum(imgs[1][SIZE-i:, :].astype(np.uint32)*imgs[3][:i,:].astype(np.uint32))
                product = sum(product) / len(product)
                if product > maxProduct:
                    maxProduct = product
                    overlap_h = i
            maxProduct = (tmp + maxProduct)/2
            
            tmp = np.zeros((SIZE, SIZE*2-overlap_w), dtype=np.uint16)
            tmp[:, :SIZE] = imgs[0]
            tmp[:, -SIZE:] += imgs[1]
            tmp[:, (SIZE-overlap_w): SIZE] = tmp[:, (SIZE-overlap_w): SIZE]/2
            
            tmp2 = np.zeros((SIZE, SIZE*2-overlap_w), dtype=np.uint16)
            tmp2[:, :SIZE] = imgs[2]
            tmp2[:, -SIZE:] += imgs[3]
            tmp2[:, (SIZE-overlap_w): SIZE] = tmp2[:, (SIZE-overlap_w): SIZE]/2
            
            merge = np.zeros((SIZE*2-overlap_h, SIZE*2-overlap_w), dtype=np.uint16)
            merge[:SIZE, :] = tmp
            merge[-SIZE:, :] += tmp2
            merge[(SIZE-overlap_h):SIZE, :] = merge[(SIZE-overlap_h):SIZE, :]/2
            
                
            offset_w = int(overlap_w/2)
            offset_h = int(overlap_h/2)
            out[SIZE*2+offset_h:SIZE*4-overlap_h+offset_h, SIZE*2+offset_w: SIZE*4-overlap_w+offset_w] = merge
            '''

            out = out.astype(np.uint8)
            out = exponential(out, EXPONENTAL_VALUE)

            out = cv2.cvtColor(out,cv2.COLOR_GRAY2BGR)
            # Position/speed tracking; disabled via `False and` (see note above).
            if False and maxProduct > PRODUCTION_THRESHOLD:
                print ('XDDDD',maxProduct)
                position = [0,0]
                rows,cols = merge.shape
                position[0] = sum(sum(H_ARRAY[:rows,:cols]*merge))/sum(sum(merge))
                position[1] = sum(sum(W_ARRAY[:rows,:cols]*merge))/sum(sum(merge))
                pos_w = distanceBetweenSensors_w/(SIZE-overlap_w)*position[0]
                pos_h = distanceBetweenSensors_h/(SIZE-overlap_h)*position[1]
                cv2.circle(out, (SIZE*2+offset_w+int(position[1]), SIZE*2+offset_h+int(position[0])), 10, (255,0,0), 5)
                if not hasPos:
                    startPos = [pos_w, pos_h]
                    startTime = frames[0].time
                    hasPos = True
                endPos = [pos_w, pos_h]
                endTime = frames[0].time
            elif hasPos:
                if endTime - startTime > MIN_EXIST_TIME:
                    print (startPos, endPos)
                    print ('speed:', ((endPos[0]-startPos[0])**2+(endPos[1]-startPos[1])**2)**0.5/(endTime - startTime))
                    print ('time:', endTime-startTime)
                hasPos = False
            if endTime - startTime > MIN_EXIST_TIME:
                speed = ((endPos[0]-startPos[0])**2+(endPos[1]-startPos[1])**2)**0.5/(endTime - startTime)
                cv2.putText(out, f'{speed:.2f}',
                        (0, SIZE*2),cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA)

            cv2.imshow('sample', out)
            videoWriter.write(out)
        key = cv2.waitKey(1)
        if key == ord('q'):
            break
        elif key == ord('c'):
            # Nothing has been rendered until the background is learned.
            if out is None:
                print ('nothing to capture yet')
            else:
                cv2.imwrite('out.jpg', out)
                log_frames('log_captured.txt', frames)
        time.sleep(0.001)
    grideye.stop()
    grideye2.stop()
    videoWriter.release()
    siftVideoWriter.release()
    cv2.destroyAllWindows()