wasm_server.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605
  1. '''
  2. /* Copyright (C) 2019 Intel Corporation. All rights reserved.
  3. * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  4. */
  5. '''
  6. import select
  7. import socket
  8. import queue
  9. from time import sleep
  10. import struct
  11. import threading
  12. import time
  13. from ctypes import *
  14. import json
  15. import logging
  16. import os
  # Attribute type names, indexed by the one-byte type tag read from an
  # encoded attribute container (see decode_attr_container).
  attr_type_list = [
      "ATTR_NONE",            # 0
      "ATTR_TYPE_SHORT",      # 1
      "ATTR_TYPE_INT",        # 2
      "ATTR_TYPE_INT64",      # 3
      "ATTR_TYPE_BYTE",       # 4
      "ATTR_TYPE_UINT16",     # 5
      "ATTR_TYPE_FLOAT",      # 6
      "ATTR_TYPE_DOUBLE",     # 7
      "ATTR_TYPE_BOOLEAN",    # 8
      "ATTR_TYPE_STRING",     # 9
      "ATTR_TYPE_BYTEARRAY",  # 10
  ]
  # Receive phases of the imrt_link_message byte-at-a-time parser.
  Phase_Non_Start = 0  # waiting for the first leading byte (0x12)
  Phase_Leading = 1    # got 0x12, waiting for 0x34
  Phase_Type = 2       # collecting the 2-byte message type
  Phase_Size = 3       # collecting the 4-byte payload size
  Phase_Payload = 4    # collecting message_size payload bytes


  class imrt_link_message(object):
      """Incremental parser for one IMRT link frame, fed one byte at a time.

      Wire layout, as parsed below: leading bytes 0x12 0x34, a 2-byte
      big-endian message type, a 4-byte big-endian payload size, then the
      payload. All bytes of the frame being parsed accumulate in ``self.msg``.
      Once a frame completes, ``message_type`` and ``message_size`` hold the
      decoded integers and ``payload`` holds the payload bytes.
      """

      def __init__(self):
          self.leading = bytes([0x12, 0x34])
          self.phase = Phase_Non_Start
          self._reset_frame()

      def _reset_frame(self):
          """Discard all per-frame state so that a new frame can be parsed."""
          self.size_in_phase = 0
          self.message_type = bytes()
          self.message_size = bytes()
          self.payload = bytes()
          self.msg = bytes()

      def set_recv_phase(self, phase):
          """Switch the parser to *phase* (one of the Phase_* constants)."""
          self.phase = phase

      def on_imrt_link_byte_arrive(self, ch):
          """Feed one received byte into the parser.

          Args:
              ch: a ``bytes`` object of length 1, e.g. from ``sock.recv(1)``.
                  An empty ``bytes`` (what recv() returns when the peer closed)
                  is rejected without touching the parser state.

          Returns:
              -1 if the byte was rejected (empty input, or not part of a
              leading sequence), 0 when the byte completes a frame, 2 when a
              payload byte was stored and more are expected, 1 otherwise.
          """
          if not ch:
              return -1

          if self.phase == Phase_Non_Start:
              if ch != b'\x12':
                  # Not a frame start: drop it, so garbage never ends up in
                  # front of the next frame in self.msg (decoders read msg at
                  # fixed offsets).
                  return -1
              # A new frame begins; forget whatever a previous frame left.
              self._reset_frame()
              self.msg = ch
              self.set_recv_phase(Phase_Leading)
              return 1

          self.msg += ch

          if self.phase == Phase_Leading:
              if ch == b'\x34':
                  self.set_recv_phase(Phase_Type)
                  return 1
              if ch == b'\x12':
                  # The previous 0x12 was noise; this byte may start a frame.
                  self.msg = ch
                  return 1
              self.msg = bytes()
              self.set_recv_phase(Phase_Non_Start)
              return -1

          if self.phase == Phase_Type:
              self.message_type += ch
              self.size_in_phase += 1
              if self.size_in_phase == 2:
                  (self.message_type, ) = struct.unpack('!H', self.message_type)
                  self.size_in_phase = 0
                  self.set_recv_phase(Phase_Size)
              return 1

          if self.phase == Phase_Size:
              self.message_size += ch
              self.size_in_phase += 1
              if self.size_in_phase == 4:
                  (self.message_size, ) = struct.unpack('!I', self.message_size)
                  self.size_in_phase = 0
                  if self.message_size == 0:
                      # Empty payload: the header alone is the whole frame.
                      # message_size is an int here, so it is compared with
                      # 0 (comparing with b'\x00' was never true).
                      self.set_recv_phase(Phase_Non_Start)
                      return 0
                  self.set_recv_phase(Phase_Payload)
              return 1

          if self.phase == Phase_Payload:
              self.payload += ch
              self.size_in_phase += 1
              if self.size_in_phase == self.message_size:
                  self.size_in_phase = 0
                  self.set_recv_phase(Phase_Non_Start)
                  return 0
              return 2

          return 1
  def read_file_to_buffer(file_name):
      """Return the whole content of *file_name* as bytes.

      Returns the string "file not found" instead of raising when the file
      does not exist; install() checks for exactly that sentinel.
      """
      try:
          # Open directly (EAFP). The previous exists() check ran only after
          # open() had already raised, so it could never report a missing file.
          with open(file_name, 'rb') as file_object:
              return file_object.read()
      except FileNotFoundError:
          logging.error("file {} not found.".format(file_name))
          return "file not found"
  def decode_attr_container(msg):
      """Decode the attribute container carried by a response frame.

      Args:
          msg: raw frame bytes, as accumulated in imrt_link_message.msg.

      Returns:
          dict mapping each attribute's key name to its decoded value.

      Raises:
          struct.error, IndexError or UnicodeDecodeError on a truncated or
          malformed container.

      The container is read with native ('@') struct formats, so it assumes
      the device shares the host's byte order. The layout is: a u32 total
      length, a u16 tag length, the tag bytes, and a u16 attribute count.
      Each attribute then has a u16 key length, the key bytes (the last byte,
      presumably a NUL, is dropped), one type byte indexing attr_type_list,
      and the value.
      """
      # NOTE(review): 26 presumably is the 8-byte link header (leading, type,
      # size) plus an 18-byte response header with an empty URL. Confirm this
      # against the device-side encoder.
      buf = msg[26:]
      (total_len, tag_len) = struct.unpack('@IH', buf[0:6])
      tag_name = buf[6:6 + tag_len].decode()
      buf = buf[6 + tag_len:]
      (attr_num, ) = struct.unpack('@H', buf[0:2])
      buf = buf[2:]
      logging.info("parsed attr:")
      logging.info("total_len:{}, tag_len:{}, tag_name:{}, attr_num:{}"
                   .format(str(total_len), str(tag_len), str(tag_name), str(attr_num)))

      # Fixed-width value types map to a struct format. The slice width comes
      # from struct.calcsize, so format and buffer advance cannot disagree.
      # ATTR_TYPE_INT is read signed ('i'), like SHORT ('h') and INT64 ('q').
      fixed_formats = {
          "ATTR_TYPE_SHORT": '@h',
          "ATTR_TYPE_INT": '@i',
          "ATTR_TYPE_INT64": '@q',
          "ATTR_TYPE_BYTE": '@c',
          "ATTR_TYPE_UINT16": '@H',
          "ATTR_TYPE_FLOAT": '@f',
          "ATTR_TYPE_DOUBLE": '@d',
          "ATTR_TYPE_BOOLEAN": '@?',
      }

      attr_dict = {}
      for _ in range(attr_num):
          (key_len, ) = struct.unpack('@H', buf[0:2])
          key_name = buf[2:2 + key_len - 1].decode()
          buf = buf[2 + key_len:]
          type_index = buf[0]
          buf = buf[1:]
          attr_type = (attr_type_list[type_index]
                       if type_index < len(attr_type_list) else None)
          if attr_type in fixed_formats:
              fmt = fixed_formats[attr_type]
              width = struct.calcsize(fmt)
              (attr_value, ) = struct.unpack(fmt, buf[0:width])
              buf = buf[width:]
          elif attr_type == "ATTR_TYPE_STRING":
              # u16 length whose last byte (presumably a NUL) is dropped.
              (str_len, ) = struct.unpack('@H', buf[0:2])
              attr_value = buf[2:2 + str_len - 1].decode()
              buf = buf[2 + str_len:]
          elif attr_type == "ATTR_TYPE_BYTEARRAY":
              (byte_len, ) = struct.unpack('@I', buf[0:4])
              attr_value = buf[4:4 + byte_len]
              buf = buf[4 + byte_len:]
          else:
              # ATTR_NONE or an out-of-range tag: the value width is unknown,
              # so nothing after this point can be parsed reliably. Without
              # this stop, a stale value from the previous attribute was
              # stored, or UnboundLocalError was raised.
              logging.warning("unknown attribute type {} for key {}, stop decoding"
                              .format(type_index, key_name))
              break
          attr_dict[key_name] = attr_value
      logging.info(str(attr_dict))
      return attr_dict
  class Request():
      """A request sent to a device over the IMRT link.

      pack_request() produces an 18-byte '!2BH2IHI' header followed by the
      NUL-terminated UTF-8 URL and then the payload. The header fields are a
      constant 1, action, fmt, mid, sender, url_len and payload_len.
      send() wraps that in the link framing: leading 0x12 0x34, a 2-byte
      message type, and a 4-byte length.
      """
      mid = 0
      url = ""
      action = 0
      fmt = 0
      payload = ""
      payload_len = 0
      sender = 0

      def __init__(self, url, action, fmt, payload, payload_len):
          self.url = url
          self.action = action
          self.fmt = fmt
          self.payload_len = payload_len
          # The payload is only stored when one is announced; otherwise the
          # class default "" stays in place.
          if self.payload_len > 0:
              self.payload = payload

      def pack_request(self):
          """Serialize the request.

          Returns:
              (buffer, len(buffer)) where buffer is the packed ``bytes``.

          Raises:
              ValueError: if the payload is shorter than payload_len.
          """
          # Encode once so that url_len counts bytes, not characters. Packing
          # character by character with '!c' failed on non-ASCII app names.
          url_bytes = self.url.encode("utf8") + bytes([0])
          url_len = len(url_bytes)
          header = struct.pack('!2BH2IHI', 1, self.action, self.fmt, self.mid,
                               self.sender, url_len, self.payload_len)
          if self.payload_len > 0:
              payload = bytes(self.payload[:self.payload_len])
              if len(payload) != self.payload_len:
                  raise ValueError("payload shorter than payload_len")
          else:
              # self.payload may be the str "" default, which bytes() rejects.
              payload = b""
          # A single concatenation replaces the per-byte '+=' loop, which was
          # quadratic for large wasm payloads.
          req_buffer = header + url_bytes + payload
          return req_buffer, len(req_buffer)

      def send(self, conn, is_install):
          """Frame the packed request and write it to the socket *conn*.

          Returns:
              None on success. On a socket error it returns -1, after dropping
              *conn* from the module-level ``tcpserver`` (when one exists).
          """
          leading = struct.pack('!2B', 0x12, 0x34)
          msg_type = struct.pack('!H', 0x0004 if is_install else 0x0002)
          buff, buff_len = self.pack_request()
          lenth = struct.pack('!I', buff_len)
          try:
              # sendall(): a plain send() may write only part of a large
              # (e.g. wasm install) request.
              conn.sendall(leading + msg_type + lenth + buff)
          except socket.error as e:
              logging.error("device closed")
              # `tcpserver` is only a module global when this file runs as a
              # script, so look it up defensively instead of risking NameError.
              server = globals().get("tcpserver")
              if server is not None:
                  # Rebuild the list rather than removing while iterating.
                  server.devices[:] = [dev for dev in server.devices
                                       if dev.conn is not conn]
                  conn_dict = getattr(server, "conn_dict", None)
                  if conn_dict is not None:
                      for addr in [a for a, c in conn_dict.items() if c is conn]:
                          del conn_dict[addr]
              return -1
  def _receive_imrt_message(conn, timeout=5.0):
      """Read one complete IMRT frame from the socket *conn*.

      Returns the raw frame bytes. Returns None when no complete frame
      arrives within *timeout* seconds, when the peer closes the connection,
      or on a socket error.
      """
      receive_context = imrt_link_message()
      deadline = time.time() + timeout
      previous_timeout = conn.gettimeout()
      try:
          while True:
              remaining = deadline - time.time()
              if remaining <= 0:
                  logging.error("timed out waiting for device response")
                  return None
              # Bound each blocking recv() by the time left, so that a silent
              # device cannot block the caller past the deadline.
              conn.settimeout(remaining)
              ch = conn.recv(1)
              if not ch:
                  logging.error("device closed the connection")
                  return None
              if receive_context.on_imrt_link_byte_arrive(ch) == 0:
                  return receive_context.msg
      except OSError as e:  # socket.timeout is a subclass of OSError
          logging.error("OSError exception occur: {}".format(e))
          return None
      finally:
          try:
              conn.settimeout(previous_timeout)
          except OSError:
              pass  # the socket may already be closed


  def _describe_failure(msg, action_name):
      """Log a failed install/uninstall response and return its description.

      Returns str() of the decoded error attributes, or "fail" if the
      response cannot be decoded.
      """
      try:
          res = decode_attr_container(msg)
      except (struct.error, IndexError, UnicodeDecodeError) as e:
          logging.warning('%s application failed: undecodable response (%s)' % (action_name, e))
          return "fail"
      logging.warning('%s application failed: %s' % (action_name, str(res)))
      return str(res)


  def query(conn):
      """Ask the device on *conn* for its applet information.

      Returns the decoded attribute dict, or the string "fail".
      """
      req = Request("/applet", 1, 0, "", 0)
      if req.send(conn, False) == -1:
          return "fail"
      query_resp = _receive_imrt_message(conn)
      if query_resp is None:
          return "fail"
      logging.debug("query response: {}".format(query_resp))
      try:
          res = decode_attr_container(query_resp)
      except (struct.error, IndexError, UnicodeDecodeError) as e:
          logging.error("failed to decode query response: {}".format(e))
          return "fail"
      logging.info('Query device infomation success')
      return res


  def install(conn, app_name, wasm_file):
      """Install *wasm_file* on the device behind *conn* as *app_name*.

      Returns "success", "fail", "failed to install: file not found", or
      str() of the error attributes the device sent back.
      """
      wasm = read_file_to_buffer(wasm_file)
      if wasm == "file not found":
          return "failed to install: file not found"
      logging.info("wasm file len: {}".format(len(wasm)))
      req = Request("/applet?name=" + app_name, 3, 98, wasm, len(wasm))
      if req.send(conn, True) == -1:
          return "fail"
      msg = _receive_imrt_message(conn)
      if msg is None:
          # Previously an OSError here left msg unbound (UnboundLocalError).
          return "fail"
      # TODO: check return message
      # NOTE(review): a 24-byte frame whose byte 9 is 65 (0x41) counts as
      # success. 0x41 is presumably a CoAP-style 2.01 "Created" code; confirm.
      if len(msg) == 24 and msg[8 + 1] == 65:
          logging.info('Install application success')
          return "success"
      return _describe_failure(msg, 'Install')


  def uninstall(conn, app_name):
      """Uninstall application *app_name* from the device behind *conn*.

      Returns "success", "fail", or str() of the error attributes the
      device sent back.
      """
      req = Request("/applet?name=" + app_name, 4, 99, "", 0)
      if req.send(conn, False) == -1:
          return "fail"
      msg = _receive_imrt_message(conn)
      if msg is None:
          # Previously an OSError here left msg unbound (UnboundLocalError).
          return "fail"
      # TODO: check return message
      # NOTE(review): a 24-byte frame whose byte 9 is 66 (0x42) counts as
      # success. 0x42 is presumably a CoAP-style 2.02 "Deleted" code; confirm.
      if len(msg) == 24 and msg[8 + 1] == 66:
          logging.info('Uninstall application success')
          return "success"
      return _describe_failure(msg, 'Uninstall')
  class Device:
      """A device connection that answered the initial query."""

      def __init__(self, conn, addr, port):
          self.conn = conn    # socket connected to the device
          self.addr = addr    # peer address, from accept()'s client_address[0]
          self.port = port    # peer port, from accept()'s client_address[1]
          self.app_num = 0
          self.apps = []
          # Per-instance list. A class-level `cmd = []` is a single list
          # shared by every Device.
          self.cmd = []
  class TCPServer:
      """select()-driven server for device connections and command clients.

      Devices connect to *server*. A local command socket (by default
      127.0.0.1:8889) accepts a controller (the web server), which sends
      colon-separated text commands:

          query:all
          query:<addr>:<port>
          install:<addr>:<port>:<app_name>:<wasm_file>
          uninstall:<addr>:<port>:<app_name>

      Replies are queued in message_queues and written by handler_send().
      The inputs/outputs lists are shared with event_loop().
      """

      def __init__(self, server, server_address, inputs, outputs, message_queues,
                   cmd_address=('127.0.0.1', 8889)):
          # Listening socket for device connections.
          self.server = server
          self.server.setblocking(False)
          self.server_address = server_address
          print('starting up on %s port %s' % self.server_address)
          self.server.bind(self.server_address)
          self.server.listen(10)
          # Listening socket for command clients.
          self.cmd_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
          self.cmd_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
          self.cmd_sock.bind(cmd_address)
          self.cmd_sock.listen(5)
          # Sockets from which we expect to read.
          self.inputs = inputs
          self.inputs.append(self.cmd_sock)
          # Sockets with queued messages to send.
          self.outputs = outputs
          # Outgoing message queues (socket: Queue). Only command clients
          # get one.
          self.message_queues = message_queues
          self.devices = []
          # (addr, port) -> device socket, used to route commands.
          self.conn_dict = {}

      def handler_recever(self, readable):
          """Handle every socket that select() reported as readable."""
          for s in readable:
              if s not in self.inputs:
                  continue  # already closed earlier in this pass
              if s is self.server:
                  self._accept_device(s)
              elif s is self.cmd_sock:
                  self._accept_command_client(s)
              else:
                  self._handle_client_data(s)
          return "got it"

      def _accept_device(self, listener):
          """Accept a device connection and register it if it answers a query."""
          connection, client_address = listener.accept()
          self.client_address = client_address
          print('connection from', client_address)
          self.inputs.append(connection)
          res = query(connection)
          if res != "fail":
              dev = Device(connection, client_address[0], client_address[1])
              self.devices.append(dev)
              self.conn_dict[client_address] = connection
              logging.info('A new client connected from ("%s":"%s")' % (dev.addr, dev.port))

      def _accept_command_client(self, listener):
          """Accept a command client and give it an outgoing message queue."""
          connection, client_address = listener.accept()
          print("web server socket connected")
          logging.info("Django server connected")
          self.inputs.append(connection)
          self.message_queues[connection] = queue.Queue()

      def _handle_client_data(self, s):
          """Read from a connected socket; run commands, or close on EOF."""
          try:
              data = s.recv(1024)
          except OSError as e:
              logging.error("recv failed: {}".format(e))
              data = b''
          if data == b'':
              # Interpret an empty result (or a failed read) as a closed connection.
              self._close_connection(s)
              return
          logging.info('received "%s" from %s' % (data, s.getpeername()))
          if s not in self.message_queues:
              # Only command clients have a reply queue. Unsolicited data
              # from a device is just logged.
              return
          self._reply(s, self._handle_command(data))

      def _handle_command(self, data):
          """Execute one command from a command client.

          Returns the reply text, or None when no reply should be sent
          (undecodable or unknown command).
          """
          try:
              fields = data.decode().split(':')
          except UnicodeDecodeError:
              logging.warning(data)
              return None
          command = fields[0]
          if command == "query" and len(fields) > 1 and fields[1] == "all":
              return self._query_all()
          required = {"query": 3, "install": 5, "uninstall": 4}.get(command)
          if required is None:
              logging.warning(data)
              return None
          try:
              if len(fields) < required:
                  raise ValueError("missing fields")
              client_addr = (fields[1], int(fields[2]))
          except ValueError:
              logging.warning("malformed command: {}".format(data))
              return "fail"
          conn = self.conn_dict.get(client_addr)
          if conn is None:
              # No connection to that device. Always reply so that the
              # client is not left waiting.
              return "fail"
          if command == "query":
              print('start query device from (%s:%s)' % (client_addr[0], client_addr[1]))
              resp = query(conn)
              print(resp)
              return str(resp)
          app_name = fields[3]
          if command == "install":
              # Rejoin the tail so that a file path containing ':' survives.
              app_file = ':'.join(fields[4:])
              print('start install application %s to ("%s":"%s")' % (app_name, client_addr[0], client_addr[1]))
              res = install(conn, app_name, app_file)
          else:
              print("start uninstall")
              res = uninstall(conn, app_name)
          logging.info("response {} to cmd server".format(res))
          return res

      def _query_all(self):
          """Query every registered device and join the results with '*'."""
          resp = []
          print('start query all devices')
          # Iterate a copy: a failed send inside query() may drop devices
          # from the module-level server's list.
          for dev in list(self.devices):
              dev_info = query(dev.conn)
              if dev_info == "fail":
                  continue
              dev_info["addr"] = dev.addr
              dev_info["port"] = dev.port
              resp.append(str(dev_info))
          print(resp)
          # '*' is used by the web server to separate the entries.
          return "*".join(resp)

      def _reply(self, s, reply):
          """Queue the str *reply* for *s* and mark *s* as wanting to write."""
          if reply is None:
              return
          reply_queue = self.message_queues.get(s)
          if reply_queue is None:
              return
          reply_queue.put(bytes(reply, encoding='utf8'))
          if s not in self.outputs:
              self.outputs.append(s)

      def _close_connection(self, s):
          """Forget every reference to *s*, then close it."""
          self.devices[:] = [dev for dev in self.devices if dev.conn is not s]
          for addr in [a for a, c in self.conn_dict.items() if c is s]:
              del self.conn_dict[addr]
          if s in self.outputs:
              self.outputs.remove(s)
          if s in self.inputs:
              self.inputs.remove(s)
          self.message_queues.pop(s, None)
          try:
              s.close()
          except OSError:
              logging.error("OSError raised, unknown connection")

      def handler_send(self, writable):
          """Send one queued message to each writable socket that has one."""
          for s in writable:
              if s not in self.outputs:
                  continue  # closed or drained earlier in this pass
              message_queue = self.message_queues.get(s)
              if message_queue is None:
                  # No queue left for this socket; stop selecting it for
                  # writing, otherwise select() keeps returning it.
                  print("client has closed")
                  self.outputs.remove(s)
                  continue
              try:
                  send_data = message_queue.get_nowait()
              except queue.Empty:
                  self.outputs.remove(s)
                  continue
              try:
                  s.sendall(send_data)
              except OSError as e:
                  logging.error("failed to send reply: {}".format(e))
                  self._close_connection(s)
          return "got it"

      def handler_exception(self, exceptional):
          """Close every socket that select() reported in an exceptional state."""
          for s in exceptional:
              if s not in self.inputs:
                  continue  # already closed earlier in this pass
              try:
                  peer = s.getpeername()
              except OSError:
                  peer = None  # e.g. a listening or already-reset socket
              print('exception condition on', peer)
              self._close_connection(s)
          return "got it"
  def event_loop(tcpserver, inputs, outputs):
      """Run the select() loop until *inputs* becomes empty.

      *inputs* and *outputs* must be the same list objects that *tcpserver*'s
      handlers mutate.
      """
      while inputs:
          # Wait for at least one of the sockets to be ready for processing
          print('waiting for the next event')
          readable, writable, exceptional = select.select(inputs, outputs, inputs)
          if tcpserver.handler_recever(readable) == 'got it':
              print("server have received")
          # The read handler may have closed sockets that this same select()
          # also reported as writable or exceptional. Don't hand those on.
          writable = [s for s in writable if s in outputs]
          if tcpserver.handler_send(writable) == 'got it':
              print("server have send")
          exceptional = [s for s in exceptional if s in inputs]
          if tcpserver.handler_exception(exceptional) == 'got it':
              print("server have exception")
          sleep(0.1)
  def run_wasm_server(host='localhost', port=8888):
      """Start the server on (*host*, *port*) and run event_loop in a thread.

      Returns the started threading.Thread.
      """
      # Request.send() looks up a module-level `tcpserver` to drop dead
      # devices. Publish the server under that name, as the __main__ block
      # does; a local variable left that lookup failing.
      global tcpserver
      server_address = (host, port)
      server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
      server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
      inputs = [server]
      outputs = []
      message_queues = {}
      tcpserver = TCPServer(server, server_address, inputs, outputs, message_queues)
      task = threading.Thread(target=event_loop, args=(tcpserver, inputs, outputs))
      task.start()
      return task
  if __name__ == '__main__':
      logging.basicConfig(
          level=logging.DEBUG,
          filename='wasm_server.log',
          filemode='a',
          format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s',
      )
      # Accept device connections on every interface.
      server_address = ('0.0.0.0', 8888)
      server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
      server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
      inputs, outputs, message_queues = [server], [], {}
      # Kept at module level on purpose: Request.send() refers to the global
      # name `tcpserver`.
      tcpserver = TCPServer(server, server_address, inputs, outputs, message_queues)
      logging.info("TCP Server start at {}:{}".format(server_address[0], "8888"))
      task = threading.Thread(target=event_loop, args=(tcpserver, inputs, outputs))
      task.start()