"""
The MIT License (MIT)
Copyright © 2020 Walkline Wang (https://walkline.wang)
https://gitee.com/walkline/esp32-ble
"""
from machine import Pin, Timer
from utime import sleep_ms
from micropython import alloc_emergency_exception_buf

# Reserve a 100-byte buffer so MicroPython can report exceptions raised in
# interrupt/callback context; presumably here for KeyPad's machine.Timer scan
# callback.
alloc_emergency_exception_buf(100)


class KeyPadException(Exception):
	"""Error raised by KeyPad when its pin_set configuration is invalid."""


class KeyPad(object):
	"""
	Matrix keypad driver.

	Rows are driven as outputs and columns are read as pulled-down inputs.
	A periodic machine.Timer scans the matrix and, acting as an HID input
	source, reports:
		1. key pressed:  calls key_down_cb(row, column)
		2. key released: calls key_up_cb(row, column)

	row/column are 0-based indices into the row and column pin sequences
	given in pin_set.
	"""
	MATRIX_SCAN_PERIOD = 50 # matrix scan interval, used as the Timer period (ms)

	def __init__(self, pin_set=None, key_down_cb=None, key_up_cb=None, timer_id=11):
		"""
		:param pin_set: ((row output pin ids), (column input pin ids))
		:param key_down_cb: optional callable(row, column) called on key press
		:param key_up_cb: optional callable(row, column) called on key release
		:param timer_id: id passed to machine.Timer for the scan timer
			(default 11, the id previously hard-coded)
		:raises KeyPadException: if pin_set is malformed, has an empty pin
			group, or a pin cannot be configured
		"""
		# Explicit checks instead of assert: asserts are stripped under -O and
		# would raise AssertionError rather than KeyPadException.
		if not (isinstance(pin_set, (tuple, list)) and len(pin_set) == 2):
			raise KeyPadException("pin_set must be a tuple, e.g. ((row output io), (column input io))")

		self.__output_io_set = [] # matrix rows
		self.__input_io_set = [] # matrix columns

		try:
			for io in pin_set[0]:
				# rows start low; a scan drives one row high at a time
				self.__output_io_set.append(Pin(io, Pin.OUT, value=0))

			for io in pin_set[1]:
				# pulled down, so a column reads 1 only while connected to the
				# driven (high) row - presumably through a pressed key
				self.__input_io_set.append(Pin(io, Pin.IN, Pin.PULL_DOWN))
		except (IndexError, TypeError, ValueError):
			# TypeError: a pin group is not iterable;
			# ValueError: MicroPython rejects an invalid pin id
			raise KeyPadException("pin_set value error")

		self.__row_count = len(self.__output_io_set)
		self.__column_count = len(self.__input_io_set)

		if self.__row_count == 0 or self.__column_count == 0:
			raise KeyPadException("pin_set value error")

		self.__timer = Timer(timer_id)
		self.__key_down_cb = key_down_cb
		self.__key_up_cb = key_up_cb
		# One int per row: a sentinel 1 bit followed by one bit per column.
		# Initially every key is released (all column bits 0).
		self.__last_key_status_table = [0b1 << self.__column_count for _ in self.__output_io_set]

	def get_key_count(self):
		"""
		Return the maximum number of keys (rows * columns).
		"""
		return self.__row_count * self.__column_count

	def capture(self):
		"""
		Start scanning the matrix every MATRIX_SCAN_PERIOD ms on the scan timer.
		"""
		self.__timer.init(
			mode=Timer.PERIODIC,
			period=self.MATRIX_SCAN_PERIOD,
			callback=self.__matrix_scan
		)

	def stop(self):
		"""
		Stop scanning by deinitialising the scan timer.
		"""
		self.__timer.deinit()

	def __matrix_scan(self, timer):
		"""
		Timer callback: read every row, then report each key whose state changed.
		"""
		current_key_status_table = []

		for row_io in self.__output_io_set:
			# Drive exactly one row high; rows are released right after reading,
			# so no other row is high at this point.
			row_io.on()

			row_status = 0b1 # sentinel leading bit keeps the width at column_count + 1
			for io in self.__input_io_set:
				# shift in each column; column 0 ends up as the highest data bit
				row_status = (row_status << 1) | io.value()

			row_io.off()
			current_key_status_table.append(row_status)

		if current_key_status_table != self.__last_key_status_table:
			# debug: self.__print_status_table(current_key_status_table)
			column_count = self.__column_count

			for index_row, (current_status, last_status) in enumerate(zip(current_key_status_table, self.__last_key_status_table)):
				changed = current_status ^ last_status # bits that flipped since last scan
				if not changed:
					continue

				for index_column in range(column_count):
					mask = 1 << (column_count - 1 - index_column)
					if changed & mask:
						if current_status & mask:
							self.__trigger_key_down_cb(index_row, index_column)
						else:
							self.__trigger_key_up_cb(index_row, index_column)

			self.__last_key_status_table = current_key_status_table

	def __trigger_key_down_cb(self, row, column):
		if self.__key_down_cb is not None:
			self.__key_down_cb(row, column)

	def __trigger_key_up_cb(self, row, column):
		if self.__key_up_cb is not None:
			self.__key_up_cb(row, column)

	def __print_status_table(self, table):
		"""
		Print the key status table: one list of '0'/'1' column flags per row (debug aid).
		"""
		for index, row_status in enumerate(table):
			print("row {}: {}".format(index, list(bin(row_status)[3:])))

	
def main():
	"""
	Demo: scan a 5x5 keypad and print every key press and release.

	Returns right after starting the periodic scan and printing the key count.
	"""
	row_pins = (14, 12, 13, 23, 22) # driven as outputs
	column_pins = (32, 33, 25, 26, 27) # read as inputs

	def on_key_down(row, column):
		print("key ({}, {}) down".format(row, column))

	def on_key_up(row, column):
		print("key ({}, {}) up".format(row, column))

	keypad = KeyPad(
		(row_pins, column_pins),
		key_down_cb=on_key_down,
		key_up_cb=on_key_up
	)
	keypad.capture()

	print("Keypad max key count:", keypad.get_key_count())


if __name__ == "__main__":
	# NOTE(review): main() returns as soon as the periodic scan timer has been
	# started, so this handler presumably only fires if Ctrl+C arrives during setup.
	try:
		main()
	except KeyboardInterrupt:
		print("\nPRESS CTRL+D TO RESET DEVICE")