import numpy as np
    
    
class PM1000:
    """Register-level driver for the PM1000 (Stokes / DOP measurement) instrument.

    The actual transport (SPI, USB2, USB3, PyD3XX, UDP or TCP) is held in
    ``self.n``. Every back-end is used through ``read(addr)``,
    ``write(addr, data)`` and ``close()``; some methods additionally call
    ``setsopstream``, ``getsopstream``, ``readsdram`` and ``gethistogram``.
    Register addresses are written as ``512 + offset`` throughout.
    """

    # Transport back-end; stays None when no device could be opened.
    n = None

    def __init__(self, target='127.0.0.1'):
        """Open a connection to the instrument.

        Args:
            target: ``'SPI'``, ``'USB3'``, ``'PyD3XX'``, ``'USB2'``,
                ``'UDP<host>'`` (UDP, port 5024) or any other string, which
                is used as the host for a TCP connection on port 5025.

        If the USB3, PyD3XX, USB2 or UDP back-end reports that no device was
        found, ``self.n`` is left as ``None``.
        """
        # Back-ends are imported inside the branches so only the driver for
        # the selected transport has to be importable.
        if target == 'SPI':
            from RPI_SPI import RPI_SPI
            self.n = RPI_SPI()
        elif target == 'USB3':
            from NovoptelUSB3 import NovoptelUSB3
            self.n = NovoptelUSB3()
            if self.n.DEVNO < 0:
                self.n = None
        elif target == 'PyD3XX':
            from NovoptelPyD3XX import NovoptelPyD3XX
            self.n = NovoptelPyD3XX()
            if self.n.DEVNO is None:
                self.n = None
        elif target == 'USB2':
            from NovoptelUSB2 import NovoptelUSB2
            self.n = NovoptelUSB2()
            if self.n.d is None:
                self.n = None
        elif target.startswith('UDP'):
            from NovoptelUDP import NovoptelUDP
            self.n = NovoptelUDP(target.replace('UDP', ''), port=5024)
            if self.n.sock is None:
                self.n = None
        else:
            from NovoptelTCP import NovoptelTCP
            self.n = NovoptelTCP(target, port=5025)

    def close(self):
        """Close the transport (if any) and drop the reference to it.

        Safe to call when no device was opened or when already closed.
        """
        if self.n is not None:
            self.n.close()
        self.n = None

    #######################
    # Basic communication #
    #######################

    def read(self, addr: int) -> int:
        """Read register *addr* and return it as a Python ``int``."""
        return int(self.n.read(addr))

    def write(self, addr: int, data: int):
        """Write *data* to register *addr*."""
        self.n.write(addr, data)

    ###########################
    # General instrument data #
    ###########################

    def getfirmware(self):
        """Return register 512+128 formatted as a 4-digit hex string."""
        return "%04X" % self.n.read(512+128)

    def getserialnumber(self):
        """Return the raw serial-number register (512+133)."""
        return self.n.read(512+133)

    def getmoduletype(self):
        """Return the module type string stored in registers 512+144..512+159.

        Each register contributes two bytes (high byte first); the resulting
        32 bytes are decoded with the default (UTF-8) codec.
        """
        raw = []
        for offset in range(16):
            word = self.n.read(512+144+offset)
            raw.append(word >> 8)
            raw.append(word & 0xFF)
        return bytes(raw).decode()

    #########################
    # Calibration functions #
    #########################

    def setATE(self, ATE: int):
        """Write the ATE register (512+1)."""
        self.write(512+1, ATE)

    def getATE(self):
        """Read the ATE register (512+1)."""
        return self.n.read(512+1)

    def setnormmode(self, norm_mode: int):
        """Select the normalization mode; values are clamped to 0..2.

        0: non-normalized, 1: standard normalization, 2: exact normalization.
        """
        self.write(512+46, min(2, max(0, norm_mode)))

    def getnormmode(self):
        """Read the normalization mode register (512+46)."""
        return self.n.read(512+46)

    def setnormlevel(self, level_uW: float):
        """Set the power level S1..S3 are normalized to, in microwatts.

        Example: ``setnormlevel(2**11)`` normalizes S1-S3 to 2048 uW.
        """
        # 512+38 is the level register (also read back as the power reference
        # by getStokes(0)); 512+46 is the mode register used by setnormmode().
        self.write(512+38, round(level_uW))

    def setpowerleftshift(self, bits: int):
        """Write the power left-shift register (512+74)."""
        self.write(512+74, bits)

    def getoptfrequency(self):
        """Return the optical frequency setting in THz (register holds THz*100)."""
        return self.n.read(512+69)/100

    def setoptfrequency(self, frequency: float):
        """Set the optical frequency in THz (stored as THz*100, rounded)."""
        self.write(512+69, round(frequency*100))

    def getmincalfreq(self):
        """Return the minimum calibrated frequency in THz."""
        return self.n.read(512+70)/100

    def getmaxcalfreq(self):
        """Return the maximum calibrated frequency in THz."""
        return self.n.read(512+71)/100

    def getminfreq(self):
        """Return the minimum frequency in THz."""
        return self.n.read(512+67)/100

    def getmaxfreq(self):
        """Return the maximum frequency in THz."""
        return self.n.read(512+68)/100

    def setswitchstate(self, state: int):
        """Write the switch-state register (512+430)."""
        self.write(512+430, state)

    def getswitchstate(self):
        """Read the switch-state register (512+430), divided by 100."""
        # NOTE(review): setswitchstate() writes the raw state but this divides
        # by 100 like the frequency getters -- confirm which scaling is intended.
        return self.n.read(512+430)/100

    def gettemperature(self):
        """Return the temperature register (512+421) in units of 1/16 degree."""
        return self.n.read(512+421)/16

    def settemperature(self, tmp: float):
        """Set the heater setpoint to *tmp* and the fan setpoint to *tmp* + 1.

        Temperatures are encoded in 1/16-degree steps, as in gettemperature().
        """
        # setpoint of the heating element
        self.n.write(512+431, round(tmp*16))
        # enable the heating element
        self.n.write(512+432, 59362)
        # setpoint (max. 63°C) and hysteresis for the cooling fan:
        # hysteresis starts at bit 10, setpoint occupies the low bits
        hyst = 0
        fan_setpoint = tmp + 1
        self.n.write(512+433, round(hyst*16)*2**10 + round(fan_setpoint*16))

    ##############################
    # Data acquisition functions #
    ##############################

    def getphotocurrents(self):
        """Return the four photocurrents scaled by the shift register 512+9."""
        res = np.empty(4)
        # The first read comes before the others; the remaining registers are
        # the latched versions, so the read order must be kept.
        res[0] = self.n.read(512+2)
        res[1] = self.n.read(512+4)  # latched version
        res[2] = self.n.read(512+6)  # latched version
        res[3] = self.n.read(512+8)  # latched version
        bitshift = self.n.read(512+9)  # latched version
        return res/2**bitshift

    def _read_stokes_uW(self, addr_int, addr_frac):
        """Combine an integer register (offset 2**15) and a fraction register into a signed uW value."""
        whole = self.n.read(addr_int) - 2**15
        frac = (self.n.read(addr_frac) ^ 2**15) / 2**16
        # The fraction extends the magnitude away from zero.
        return whole - frac if whole < 0 else whole + frac

    def getStokes(self, norm: int):
        """Read one Stokes vector.

        Args:
            norm: 0 non-normalized (s[0] = power reference for S1..S3),
                1 standard normalization (s[0] = DOP),
                2 exact normalization (s[0] = DOP),
                3 values in uW (s[0] = power).

        Returns:
            numpy array of 4 floats. For an invalid *norm* a message is printed
            and an all-NaN array is returned.
        """
        s = np.full(4, np.nan)
        if norm in (0, 1, 2):
            # (s[0] register, first of the three consecutive S1..S3 registers)
            addr0, addr1 = {0: (38, 42), 1: (24, 28), 2: (31, 35)}[norm]
            s[0] = self.n.read(512+addr0)/2**15
            for k in range(3):
                s[k+1] = (self.n.read(512+addr1+k)-2**15)/2**15
        elif norm == 3:  # in uW
            s[0] = self.n.read(512+10) + self.n.read(512+11)/2**16
            for k in range(3):
                s[k+1] = self._read_stokes_uW(512+18+2*k, 512+19+2*k)
        else:
            print("invalid normalization value")
        return s

    def getDOP(self):
        """Return the DOP (standard normalization register 512+24 / 2**15)."""
        return self.n.read(512+24)/2**15

    def getPower(self):
        """Return the power in microwatts (integer part 512+10, fraction 512+11)."""
        return self.n.read(512+10) + self.n.read(512+11)/2**16

    def Data2DynSOP_fast(self, data):
        """Vectorized decoder for strictly alternating SOP / timestamp rows.

        When ATE > 0 each timestamp refers to exactly one SOP, so instead of
        iterating through the data the rows can be reshaped into pairs.

        Args:
            data: integer array with 4 columns per row (presumably from
                getsdram_raw()). A trailing unpaired row is ignored.

        Returns:
            ``(Timestamps, Stokes0, Stokes1, Stokes2, Stokes3, 1)`` with
            Timestamps in ns as uint64, Stokes0 as uint16 and Stokes1..3 as
            int16; ``(0, 0, 0, 0, 0, 0)`` if the data does not have that layout.
        """
        data = np.asarray(data)
        numsamples = data.shape[0] // 2
        if numsamples == 0:
            return 0, 0, 0, 0, 0, 0
        # Like the row-by-row decoder, an SOP row without a following
        # timestamp is dropped.
        arr = data[:2*numsamples].reshape(numsamples, 8)

        # Sanity check: column 4 must hold only timestamps, column 0 only SOP data.
        if arr[:, 4].min() < 65280 or arr[:, 0].max() >= 65280:
            return 0, 0, 0, 0, 0, 0

        Stokes0 = arr[:, 0].astype(np.uint16)
        Stokes1 = arr[:, 1].astype(np.int16)
        Stokes2 = arr[:, 2].astype(np.int16)
        Stokes3 = arr[:, 3].astype(np.int16)

        d1 = arr[:, 4].astype(np.uint64) & np.uint64(127)
        d2 = arr[:, 5].astype(np.uint64)
        d3 = arr[:, 6].astype(np.uint64)
        d4 = arr[:, 7].astype(np.uint64)
        # NOTE(review): Data2DynSOP_1PPS assigns each sample (timestamp - 10 ns);
        # this path uses the timestamp unchanged -- confirm which is intended.
        Timestamps = (d1 << np.uint64(48)) | (d2 << np.uint64(32)) | (d3 << np.uint64(16)) | d4

        return Timestamps, Stokes0, Stokes1, Stokes2, Stokes3, 1

    @staticmethod
    def _remove_offset(word, addoffset):
        """Return *word* as int, minus the 2**15 offset (mod 2**16) if addoffset > 0."""
        word = int(word)
        if addoffset > 0:
            word = (word - 2**15) % 2**16
        return word

    def _decode_timestamp(self, row, pps_sync_ext, addoffset):
        """Return ``(pps_seconds, ns)`` encoded in one timestamp row."""
        d1 = int(row[0]) & 127
        d2 = self._remove_offset(row[1], addoffset)
        d3 = self._remove_offset(row[2], addoffset)
        d4 = self._remove_offset(row[3], addoffset)
        if pps_sync_ext > 0:
            # bits 54..30: pps_counter, bits 29..0: pps_counter_ns
            timestamp_pps = (d1 << 18) + (d2 << 2) + (d3 >> 14)
            timestamp_ns = ((d3 & (2**14-1)) << 16) | d4
            return timestamp_pps, timestamp_ns
        # without PPS sync the 55 low bits are one nanosecond counter
        return 0, (d1 << 48) | (d2 << 32) | (d3 << 16) | d4

    def _decode_ppscorr(self, row, addoffset):
        """Return the internal-clock ns count between the last two PPS ticks."""
        upper = self._remove_offset(row[2], addoffset)
        lower = self._remove_offset(row[3], addoffset)
        # bits 29..0: last_pps_count; the +10 presumably adds one 10 ns clock
        # period -- confirm against the firmware documentation.
        return (upper % 2**14) * 2**16 + lower + 10

    def Data2DynSOP_1PPS(self, data, pps_sync_ext, addoffset):
        """Decode raw SDRAM rows into per-sample timestamps and Stokes values.

        ``data`` has four 16-bit words per row (presumably the output of
        getsdram_raw()). A row whose first word is >= 65280 is a timestamp;
        any other row is one SOP sample. A timestamp covers all SOP rows since
        the previous timestamp, which are taken to be 10 ns apart and to end
        10 ns before it.

        Timestamp layout with ``pps_sync_ext > 0``:

        * Bits 63..56: 11111111 (timestamp indicator)
        * Bit 55: reserved
        * Bits 54..30: pps_counter
        * Bits 29..0: pps_counter_ns

        If the sampling time is 10 ns and SOP changes surpassed the threshold
        within the last 10 ns, the timestamp is skipped.

        Once after every rising edge of the PPS signal, a second timestamp is
        written that contains the nanosecond count between the last two PPS
        ticks, measured with the internal (slightly wrong) clock. Since there
        should be 1e9 ns between the ticks, this yields a correction factor
        for the pps_counter_ns values of the second before:

        * Bits 63..56: 11111111 (timestamp indicator)
        * Bits 55..30: reserved
        * Bits 29..0: last_pps_count

        With ``pps_sync_ext == 0`` the low 55 bits are a plain ns counter.

        Args:
            data: integer array of shape (N, 4).
            pps_sync_ext: > 0 if timestamps carry PPS seconds (see above).
            addoffset: > 0 if words 1..3 of every row carry an offset of 2**15.

        Returns:
            ``(Timestamps, Stokes0, Stokes1, Stokes2, Stokes3)``: uint64 ns,
            uint16, int16, int16, int16 arrays with one entry per SOP sample
            that is followed by a timestamp. If ``data`` contains no
            timestamps, an error is printed and all arrays are empty.
        """
        if data.shape[0] == 0 or data[0:, 0].max() < 65280:
            print('Error: No timestamps found in data')
            return (np.array([], dtype=np.uint64), np.array([], dtype=np.uint16),
                    np.array([], dtype=np.int16), np.array([], dtype=np.int16),
                    np.array([], dtype=np.int16))

        if pps_sync_ext == 0 and addoffset == 0:  # we can use the fast routine
            Timestamps, Stokes0, Stokes1, Stokes2, Stokes3, ok = self.Data2DynSOP_fast(data)
            if ok > 0:  # fast routine succeeded
                return Timestamps, Stokes0, Stokes1, Stokes2, Stokes3

        nrows = data.shape[0]
        sample_rows = []        # data row of every decoded SOP sample
        ts_pps = []             # PPS seconds per sample
        ts_ns = []              # ns per sample, PPS-corrected below
        ppskorrfact = 1         # int 1 keeps uncorrected ns values exact
        ppskorrfactlist = []
        last_ts_index = -1      # row of the most recent timestamp
        last_tscorr_index = -1  # last sample index already PPS-corrected

        ii = 0
        if data[0, 0] >= 65280:  # If first entry in memory is a timestamp, no data belongs to it.
            ii = 1
            last_ts_index = 0

        while ii < nrows:
            if data[ii, 0] < 65280:  # SOP data: keep searching for the next timestamp
                ii += 1
                continue

            timestamp_pps, timestamp_ns = self._decode_timestamp(data[ii], pps_sync_ext, addoffset)
            # Every SOP row since the previous timestamp is covered by this one.
            for tt in range(last_ts_index + 1, ii):
                sample_rows.append(tt)
                ts_pps.append(timestamp_pps)
                ts_ns.append(timestamp_ns - (ii - tt) * 10)
            last_ts_index = ii
            ii += 1

            if ii < nrows and data[ii, 0] >= 65280:  # another timestamp: pps correction
                ppskorrfact = 1e9 / self._decode_ppscorr(data[ii], addoffset)
                print('1-ppskorrfact: ', 1-ppskorrfact)
                ppskorrfactlist.append(1-ppskorrfact)

                # Correct all pending samples except the newest one, and let the
                # next correction resume right after the last corrected sample
                # so that no sample is left uncorrected.
                currind = len(ts_ns)
                for tt in range(last_tscorr_index + 1, currind - 1):
                    ts_ns[tt] = int(ts_ns[tt] * ppskorrfact)
                last_tscorr_index = max(currind - 2, last_tscorr_index)

                last_ts_index = ii
                ii += 1

        # the timestamps after the last pps timestamp are corrected using the last known factor
        for tt in range(last_tscorr_index + 1, len(ts_ns)):
            ts_ns[tt] = int(ts_ns[tt] * ppskorrfact)

        if pps_sync_ext > 0:
            if len(ppskorrfactlist) == 0:
                print("Warning: No PPS timestamps found in data!")
            else:
                print(len(ppskorrfactlist), 'PPS timestamps found (max. abs =', max(ppskorrfactlist, key=abs), ').')

        # Integer arithmetic keeps full ns resolution (float64 would not above 2**53).
        Timestamps = (np.asarray(ts_pps, dtype=np.int64) * 10**9
                      + np.asarray(ts_ns, dtype=np.int64)).astype(np.uint64)

        rows = np.asarray(sample_rows, dtype=np.intp)
        Stokes0 = data[rows, 0].astype(np.uint16)
        if addoffset > 0:
            Stokes1 = (data[rows, 1].astype(np.int32) - 2**15).astype(np.int16)
            Stokes2 = (data[rows, 2].astype(np.int32) - 2**15).astype(np.int16)
            Stokes3 = (data[rows, 3].astype(np.int32) - 2**15).astype(np.int16)
        else:
            Stokes1 = data[rows, 1].astype(np.int16)
            Stokes2 = data[rows, 2].astype(np.int16)
            Stokes3 = data[rows, 3].astype(np.int16)

        return Timestamps, Stokes0, Stokes1, Stokes2, Stokes3

    #################
    # SOP streaming #
    #################

    def setSOPstreamperiod(self, sr: int):
        """Set the SOP stream period in microseconds (max. 2**15-1)."""
        self.n.write(512+109, sr)

    def getSOPstreamperiod(self):
        """Return the SOP stream period in microseconds (max. 2**15-1)."""
        return self.n.read(512+109)

    def setSOPstream(self, enable: int):
        """Enable (1) or disable (0) SOP streaming via the transport."""
        self.n.setsopstream(enable)

    def getSOPstream(self, norm: int):
        """Return streamed SOP data from the transport for normalization *norm*."""
        return self.n.getsopstream(norm)

    ###################################
    # SDRAM recording & data transfer #
    ###################################

    def pclear(self):
        """Clear the Poincare sphere on the HDMI output."""
        self.n.write(512+189, 0)

    def setME(self, ME: int):
        """Set the memory exponent (512+73)."""
        self.n.write(512+73, ME)

    def getME(self):
        """Return the memory exponent (low byte of 512+73)."""
        return self.n.read(512+73) & 0xFF

    def getMaxME(self):
        """Return the maximum memory exponent (high byte of 512+73).

        Values outside 26..27 are replaced by 26.
        """
        MaxME = self.n.read(512+73) >> 8
        if MaxME < 26 or MaxME > 27:
            MaxME = 26
        return MaxME

    def isBusy(self):
        """Return the busy flag (bit 0 of 512+72)."""
        return self.n.read(512 + 72) & 1

    def isWaiting(self):
        """Return the waiting flag (bit 15 of 512+72)."""
        return self.n.read(512 + 72) >> 15

    def waitBusy(self, poll_interval: float = 0):
        """Block until the busy flag clears.

        Args:
            poll_interval: seconds to sleep between polls; 0 (default) polls
                continuously.
        """
        import time
        while self.isBusy() > 0:
            if poll_interval > 0:
                time.sleep(poll_interval)

    def getrecaddress(self):
        """Return the recording address as uint32 (512+76 low, 512+77 high word)."""
        addr = np.uint32(self.n.read(512+76))
        addr = addr + (np.uint32(self.n.read(512+77)) * 2**16)
        return addr

    def getrecblock(self):
        """Return the recording block register (512+78)."""
        return self.n.read(512+78)

    def start_recording(self,
                        pps_sync_ext: int = 0,
                        dyn_undersampling: int = 0,
                        trigonnextbnc: int = 0,
                        syncextfalling: int = 0,
                        syncextrising: int = 0,
                        rearm_dno: int = 0,
                        rearm: int = 0,
                        cyclic: int = 0):
        """Start an SDRAM recording, stopping any recording in progress first.

        The flags are packed into register 512+72: bit 9 pps_sync_ext,
        8 dyn_undersampling, 7 trigonnextbnc, 6 syncextfalling,
        5 syncextrising, 4 rearm_dno, 3 rearm, 2 cyclic, and bit 0 is set
        unless trigonnextbnc == 1.
        """
        trigonnextbnc_n = 0 if trigonnextbnc == 1 else 1

        val = (pps_sync_ext * 2**9
               + dyn_undersampling * 2**8
               + trigonnextbnc * 2**7
               + syncextfalling * 2**6
               + syncextrising * 2**5
               + rearm_dno * 2**4
               + rearm * 2**3
               + cyclic * 2**2
               + trigonnextbnc_n)

        # Stop any ongoing recording in progress:
        if self.isBusy() > 0:
            self.n.write(512+72, 0)

        print("REC started with code {0:016b} ({0})".format(val))
        self.n.write(512+72, val)

    def getsdram(self, startaddr: int, numaddr: int, normalization=-1):
        """Read *numaddr* SDRAM addresses from *startaddr*; return four data columns."""
        data0, data1, data2, data3 = self.n.readsdram(startaddr, numaddr, normalization)
        return data0, data1, data2, data3

    def getsdram_raw(self, startaddr: int, numaddr: int):
        """Read raw SDRAM data and return it as an array of shape (numaddr-ish, 4)."""
        data0, data1, data2, data3 = self.n.readsdram(startaddr, numaddr, normalization=-1)
        return np.array([data0, data1, data2, data3]).transpose()

    ##############
    # Histograms #
    ##############

    def startspeedhist(self, HistVexp=3, HistClkDiv=7, HistTau=6):
        """Configure and start the speed histogram (512+234, 512+236, start bit 0 of 512+232)."""
        self.n.write(512 + 234, HistVexp)
        self.n.write(512 + 236, HistClkDiv * 2 ** 6 + HistTau - 1)
        self.n.write(512 + 232, 2**0)

    def stopspeedhist(self):
        """Stop the speed histogram (bit 1 of 512+232)."""
        self.n.write(512 + 232, 2**1)

    def getspeedhist(self):
        """Return the speed histogram (histtype 0) from the transport."""
        return self.n.gethistogram(histtype=0)

    def startpowhist(self, typesel=0, HistVexp=3, HistClkDiv=7, HistTau=6):
        """Configure and start the power histogram.

        typesel 0: histogram of the overall power;
        typesel 1: histogram of power deviation per time.
        """
        self.n.write(512 + 237, HistVexp * 2**3 + typesel)
        self.n.write(512 + 238, HistClkDiv * 2 ** 6 + HistTau - 1)
        self.n.write(512 + 232, 2**2)

    def stoppowhist(self):
        """Stop the power histogram (bit 3 of 512+232)."""
        self.n.write(512 + 232, 2**3)

    def getpowhist(self):
        """Return the power histogram (histtype 1) from the transport."""
        return self.n.gethistogram(histtype=1)

    ####################
    # PPS synchronization #
    ####################

    def setppssync(self, enable: int, config=-1):
        """Enable (1) or disable (0) PPS synchronization (register 512+211).

        Requires firmware >= 1.1.0.0 and a PM1000 main board with AMD Artix 7 FPGA.

        Args:
            enable: 1 to enable, 0 to disable (bit 0).
            config: if negative, the other configuration bits are kept and only
                bit 0 is changed; otherwise ``config * 2 + enable`` is written.
        """
        if config < 0:  # don't modify configuration
            conf = self.n.read(512+211)
            if enable == 0:
                self.n.write(512+211, conf & 0xFFFE)
            else:
                self.n.write(512+211, conf | 1)
        else:
            self.n.write(512+211, config * 2**1 + enable)

    def getppssyncstatus(self):
        """Print and return the PPS synchronization status words.

        Bits 9 and up of register 512+211 select which status word appears in
        register 512+212; the low 9 configuration bits are preserved.

        Returns:
            dict with 'error_code', 'shift_period', 'clock_corr_ns' (None when
            no clock correction is reported), 'deviation_ns' and 'acc_delta_us'.
        """
        conf = self.n.read(512+211) & 0x1FF

        def select(sel):
            self.n.write(512+211, sel*2**9 + conf)
            return self.n.read(512+212)

        err = select(0)
        print("error code =", err)

        shift_period = select(1)  # in ns*10
        clock_corr_ns = None
        if shift_period > 0:
            if shift_period >= 2**15:  # incdec=1
                shift_period = 2**15-shift_period
            print("ps_period  =", shift_period)
            # A raw value of exactly 2**15 decodes to 0 and would divide by zero.
            if shift_period != 0:
                shift_period_ns = 10 * shift_period
                shift_per_second = 1e9 / shift_period_ns
                clock_corr_ns = round(shift_per_second * 1 / 56)
                print("clock_corr = %0.2f us/s" % (clock_corr_ns/1000))

        deviation = select(2)
        if deviation >= 2**15:  # signed
            deviation = deviation-2**16
        deviation_ns = 20*deviation
        print("deviation  =", deviation_ns, "ns")

        acc_delta = select(3)
        if acc_delta >= 2**15:  # signed
            acc_delta = acc_delta-2**16
        print("acc_delta  = %0.2f us" % (acc_delta/100))

        return {
            'error_code': err,
            'shift_period': shift_period,
            'clock_corr_ns': clock_corr_ns,
            'deviation_ns': deviation_ns,
            'acc_delta_us': acc_delta/100,
        }


