wasm_server.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621
  1. '''
  2. /* Copyright (C) 2019 Intel Corporation. All rights reserved.
  3. * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  4. */
  5. '''
  6. import select
  7. import socket
  8. import queue
  9. from time import sleep
  10. import struct
  11. import threading
  12. import time
  13. from ctypes import *
  14. import json
  15. import logging
  16. import os
  17. attr_type_list = [
  18. "ATTR_TYPE_BYTE", # = ATTR_TYPE_INT8
  19. "ATTR_TYPE_SHORT",# = ATTR_TYPE_INT16
  20. "ATTR_TYPE_INT", # = ATTR_TYPE_INT32
  21. "ATTR_TYPE_INT64",
  22. "ATTR_TYPE_UINT8",
  23. "ATTR_TYPE_UINT16",
  24. "ATTR_TYPE_UINT32",
  25. "ATTR_TYPE_UINT64",
  26. "ATTR_TYPE_FLOAT",
  27. "ATTR_TYPE_DOUBLE",
  28. "ATTR_NONE",
  29. "ATTR_NONE",
  30. "ATTR_TYPE_BOOLEAN",
  31. "ATTR_TYPE_STRING",
  32. "ATTR_TYPE_BYTEARRAY"
  33. ]
  34. Phase_Non_Start = 0
  35. Phase_Leading = 1
  36. Phase_Type = 2
  37. Phase_Size = 3
  38. Phase_Payload = 4
  39. class imrt_link_message(object):
  40. def __init__(self):
  41. self.leading = bytes([0x12, 0x34])
  42. self.phase = Phase_Non_Start
  43. self.size_in_phase = 0
  44. self.message_type = bytes()
  45. self.message_size = bytes()
  46. self.payload = bytes()
  47. self.msg = bytes()
  48. def set_recv_phase(self, phase):
  49. self.phase = phase
  50. def on_imrt_link_byte_arrive(self, ch):
  51. self.msg += ch
  52. if self.phase == Phase_Non_Start:
  53. if ch == b'\x12':
  54. self.set_recv_phase(Phase_Leading)
  55. else:
  56. return -1
  57. elif self.phase == Phase_Leading:
  58. if ch == b'\x34':
  59. self.set_recv_phase(Phase_Type)
  60. else:
  61. self.set_recv_phase(Phase_Non_Start)
  62. return -1
  63. elif self.phase == Phase_Type:
  64. self.message_type += ch
  65. self.size_in_phase += 1
  66. if self.size_in_phase == 2:
  67. (self.message_type, ) = struct.unpack('!H', self.message_type)
  68. self.size_in_phase = 0
  69. self.set_recv_phase(Phase_Size)
  70. elif self.phase == Phase_Size:
  71. self.message_size += ch
  72. self.size_in_phase += 1
  73. if self.size_in_phase == 4:
  74. (self.message_size, ) = struct.unpack('!I', self.message_size)
  75. self.size_in_phase = 0
  76. self.set_recv_phase(Phase_Payload)
  77. if self.message_size == b'\x00':
  78. self.set_recv_phase(Phase_Non_Start)
  79. return 0
  80. self.set_recv_phase(Phase_Payload)
  81. elif self.phase == Phase_Payload:
  82. self.payload += ch
  83. self.size_in_phase += 1
  84. if self.size_in_phase == self.message_size:
  85. self.set_recv_phase(Phase_Non_Start)
  86. return 0
  87. return 2
  88. return 1
  89. def read_file_to_buffer(file_name):
  90. file_object = open(file_name, 'rb')
  91. buffer = None
  92. if not os.path.exists(file_name):
  93. logging.error("file {} not found.".format(file_name))
  94. return "file not found"
  95. try:
  96. buffer = file_object.read()
  97. finally:
  98. file_object.close()
  99. return buffer
  100. def decode_attr_container(msg):
  101. attr_dict = {}
  102. buf = msg[26 : ]
  103. (total_len, tag_len) = struct.unpack('@IH', buf[0 : 6])
  104. tag_name = buf[6 : 6 + tag_len].decode()
  105. buf = buf[6 + tag_len : ]
  106. (attr_num, ) = struct.unpack('@H', buf[0 : 2])
  107. buf = buf[2 : ]
  108. logging.info("parsed attr:")
  109. logging.info("total_len:{}, tag_len:{}, tag_name:{}, attr_num:{}"
  110. .format(str(total_len), str(tag_len), str(tag_name), str(attr_num)))
  111. for i in range(attr_num):
  112. (key_len, ) = struct.unpack('@H', buf[0 : 2])
  113. key_name = buf[2 : 2 + key_len - 1].decode()
  114. buf = buf[2 + key_len : ]
  115. (type_index, ) = struct.unpack('@c', buf[0 : 1])
  116. attr_type = attr_type_list[int(type_index[0])]
  117. buf = buf[1 : ]
  118. if attr_type == "ATTR_TYPE_BYTE": # = ATTR_TYPE_INT8
  119. (attr_value, ) = struct.unpack('@c', buf[0 : 1])
  120. buf = buf[1 : ]
  121. # continue
  122. elif attr_type == "ATTR_TYPE_SHORT": # = ATTR_TYPE_INT16
  123. (attr_value, ) = struct.unpack('@h', buf[0 : 2])
  124. buf = buf[2 : ]
  125. # continue
  126. elif attr_type == "ATTR_TYPE_INT": # = ATTR_TYPE_INT32
  127. (attr_value, ) = struct.unpack('@i', buf[0 : 4])
  128. buf = buf[4 : ]
  129. # continue
  130. elif attr_type == "ATTR_TYPE_INT64":
  131. (attr_value, ) = struct.unpack('@q', buf[0 : 8])
  132. buf = buf[8 : ]
  133. # continue
  134. elif attr_type == "ATTR_TYPE_UINT8":
  135. (attr_value, ) = struct.unpack('@B', buf[0 : 1])
  136. buf = buf[1 : ]
  137. # continue
  138. elif attr_type == "ATTR_TYPE_UINT16":
  139. (attr_value, ) = struct.unpack('@H', buf[0 : 2])
  140. buf = buf[2 : ]
  141. # continue
  142. elif attr_type == "ATTR_TYPE_UINT32":
  143. (attr_value, ) = struct.unpack('@I', buf[0 : 4])
  144. buf = buf[4 : ]
  145. # continue
  146. elif attr_type == "ATTR_TYPE_UINT64":
  147. (attr_value, ) = struct.unpack('@Q', buf[0 : 8])
  148. buf = buf[8 : ]
  149. # continue
  150. elif attr_type == "ATTR_TYPE_FLOAT":
  151. (attr_value, ) = struct.unpack('@f', buf[0 : 4])
  152. buf = buf[4 : ]
  153. # continue
  154. elif attr_type == "ATTR_TYPE_DOUBLE":
  155. (attr_value, ) = struct.unpack('@d', buf[0 : 8])
  156. buf = buf[8 : ]
  157. # continue
  158. elif attr_type == "ATTR_TYPE_BOOLEAN":
  159. (attr_value, ) = struct.unpack('@?', buf[0 : 1])
  160. buf = buf[1 : ]
  161. # continue
  162. elif attr_type == "ATTR_TYPE_STRING":
  163. (str_len, ) = struct.unpack('@H', buf[0 : 2])
  164. attr_value = buf[2 : 2 + str_len - 1].decode()
  165. buf = buf[2 + str_len : ]
  166. # continue
  167. elif attr_type == "ATTR_TYPE_BYTEARRAY":
  168. (byte_len, ) = struct.unpack('@I', buf[0 : 4])
  169. attr_value = buf[4 : 4 + byte_len]
  170. buf = buf[4 + byte_len : ]
  171. # continue
  172. attr_dict[key_name] = attr_value
  173. logging.info(str(attr_dict))
  174. return attr_dict
  175. class Request():
  176. mid = 0
  177. url = ""
  178. action = 0
  179. fmt = 0
  180. payload = ""
  181. payload_len = 0
  182. sender = 0
  183. def __init__(self, url, action, fmt, payload, payload_len):
  184. self.url = url
  185. self.action = action
  186. self.fmt = fmt
  187. # if type(payload) == bytes:
  188. # self.payload = bytes(payload, encoding = "utf8")
  189. # else:
  190. self.payload_len = payload_len
  191. if self.payload_len > 0:
  192. self.payload = payload
  193. def pack_request(self):
  194. url_len = len(self.url) + 1
  195. buffer_len = url_len + self.payload_len
  196. req_buffer = struct.pack('!2BH2IHI',1, self.action, self.fmt, self.mid, self.sender, url_len, self.payload_len)
  197. for i in range(url_len - 1):
  198. req_buffer += struct.pack('!c', bytes(self.url[i], encoding = "utf8"))
  199. req_buffer += bytes([0])
  200. for i in range(self.payload_len):
  201. req_buffer += struct.pack('!B', self.payload[i])
  202. return req_buffer, len(req_buffer)
  203. def send(self, conn, is_install):
  204. leading = struct.pack('!2B', 0x12, 0x34)
  205. if not is_install:
  206. msg_type = struct.pack('!H', 0x0002)
  207. else:
  208. msg_type = struct.pack('!H', 0x0004)
  209. buff, buff_len = self.pack_request()
  210. lenth = struct.pack('!I', buff_len)
  211. try:
  212. conn.send(leading)
  213. conn.send(msg_type)
  214. conn.send(lenth)
  215. conn.send(buff)
  216. except socket.error as e:
  217. logging.error("device closed")
  218. for dev in tcpserver.devices:
  219. if dev.conn == conn:
  220. tcpserver.devices.remove(dev)
  221. return -1
  222. def query(conn):
  223. req = Request("/applet", 1, 0, "", 0)
  224. if req.send(conn, False) == -1:
  225. return "fail"
  226. time.sleep(0.05)
  227. try:
  228. receive_context = imrt_link_message()
  229. start = time.time()
  230. while True:
  231. if receive_context.on_imrt_link_byte_arrive(conn.recv(1)) == 0:
  232. break
  233. elif time.time() - start >= 5.0:
  234. return "fail"
  235. query_resp = receive_context.msg
  236. print(query_resp)
  237. except OSError as e:
  238. logging.error("OSError exception occur")
  239. return "fail"
  240. res = decode_attr_container(query_resp)
  241. logging.info('Query device infomation success')
  242. return res
  243. def install(conn, app_name, wasm_file):
  244. wasm = read_file_to_buffer(wasm_file)
  245. if wasm == "file not found":
  246. return "failed to install: file not found"
  247. print("wasm file len:")
  248. print(len(wasm))
  249. req = Request("/applet?name=" + app_name, 3, 98, wasm, len(wasm))
  250. if req.send(conn, True) == -1:
  251. return "fail"
  252. time.sleep(0.05)
  253. try:
  254. receive_context = imrt_link_message()
  255. start = time.time()
  256. while True:
  257. if receive_context.on_imrt_link_byte_arrive(conn.recv(1)) == 0:
  258. break
  259. elif time.time() - start >= 5.0:
  260. return "fail"
  261. msg = receive_context.msg
  262. except OSError as e:
  263. logging.error("OSError exception occur")
  264. # TODO: check return message
  265. if len(msg) == 24 and msg[8 + 1] == 65:
  266. logging.info('Install application success')
  267. return "success"
  268. else:
  269. res = decode_attr_container(msg)
  270. logging.warning('Install application failed: %s' % (str(res)))
  271. print(str(res))
  272. return str(res)
  273. def uninstall(conn, app_name):
  274. req = Request("/applet?name=" + app_name, 4, 99, "", 0)
  275. if req.send(conn, False) == -1:
  276. return "fail"
  277. time.sleep(0.05)
  278. try:
  279. receive_context = imrt_link_message()
  280. start = time.time()
  281. while True:
  282. if receive_context.on_imrt_link_byte_arrive(conn.recv(1)) == 0:
  283. break
  284. elif time.time() - start >= 5.0:
  285. return "fail"
  286. msg = receive_context.msg
  287. except OSError as e:
  288. logging.error("OSError exception occur")
  289. # TODO: check return message
  290. if len(msg) == 24 and msg[8 + 1] == 66:
  291. logging.info('Uninstall application success')
  292. return "success"
  293. else:
  294. res = decode_attr_container(msg)
  295. logging.warning('Uninstall application failed: %s' % (str(res)))
  296. print(str(res))
  297. return str(res)
  298. class Device:
  299. def __init__(self, conn, addr, port):
  300. self.conn = conn
  301. self.addr = addr
  302. self.port = port
  303. self.app_num = 0
  304. self.apps = []
  305. cmd = []
  306. class TCPServer:
  307. def __init__(self, server, server_address, inputs, outputs, message_queues):
  308. # Create a TCP/IP
  309. self.server = server
  310. self.server.setblocking(False)
  311. # Bind the socket to the port
  312. self.server_address = server_address
  313. print('starting up on %s port %s' % self.server_address)
  314. self.server.bind(self.server_address)
  315. # Listen for incoming connections
  316. self.server.listen(10)
  317. self.cmd_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  318. self.cmd_sock.setsockopt(socket.SOL_SOCKET,socket.SO_REUSEADDR,1)
  319. self.cmd_sock.bind(('127.0.0.1', 8889))
  320. self.cmd_sock.listen(5)
  321. # Sockets from which we expect to read
  322. self.inputs = inputs
  323. self.inputs.append(self.cmd_sock)
  324. # Sockets to which we expect to write
  325. # 处理要发送的消息
  326. self.outputs = outputs
  327. # Outgoing message queues (socket: Queue)
  328. self.message_queues = message_queues
  329. self.devices = []
  330. self.conn_dict = {}
  331. def handler_recever(self, readable):
  332. # Handle inputs
  333. for s in readable:
  334. if s is self.server:
  335. # A "readable" socket is ready to accept a connection
  336. connection, client_address = s.accept()
  337. self.client_address = client_address
  338. print('connection from', client_address)
  339. # this is connection not server
  340. # connection.setblocking(0)
  341. self.inputs.append(connection)
  342. # Give the connection a queue for data we want to send
  343. # self.message_queues[connection] = queue.Queue()
  344. res = query(connection)
  345. if res != "fail":
  346. dev = Device(connection, client_address[0], client_address[1])
  347. self.devices.append(dev)
  348. self.conn_dict[client_address] = connection
  349. dev_info = {}
  350. dev_info['addr'] = dev.addr
  351. dev_info['port'] = dev.port
  352. dev_info['apps'] = 0
  353. logging.info('A new client connected from ("%s":"%s")' % (dev.conn, dev.port))
  354. elif s is self.cmd_sock:
  355. connection, client_address = s.accept()
  356. print("web server socket connected")
  357. logging.info("Django server connected")
  358. self.inputs.append(connection)
  359. self.message_queues[connection] = queue.Queue()
  360. else:
  361. data = s.recv(1024)
  362. if data != b'':
  363. # A readable client socket has data
  364. logging.info('received "%s" from %s' % (data, s.getpeername()))
  365. # self.message_queues[s].put(data)
  366. # # Add output channel for response
  367. # if s not in self.outputs:
  368. # self.outputs.append(s)
  369. if(data.decode().split(':')[0] == "query"):
  370. if data.decode().split(':')[1] == "all":
  371. resp = []
  372. print('start query all devices')
  373. for dev in self.devices:
  374. dev_info = query(dev.conn)
  375. if dev_info == "fail":
  376. continue
  377. dev_info["addr"] = dev.addr
  378. dev_info["port"] = dev.port
  379. resp.append(str(dev_info))
  380. print(resp)
  381. if self.message_queues[s] is not None:
  382. # '*' is used in web server to sperate the string
  383. self.message_queues[s].put(bytes("*".join(resp), encoding = 'utf8'))
  384. if s not in self.outputs:
  385. self.outputs.append(s)
  386. else:
  387. client_addr = (data.decode().split(':')[1],int(data.decode().split(':')[2]))
  388. if client_addr in self.conn_dict.keys():
  389. print('start query device from (%s:%s)' % (client_addr[0], client_addr[1]))
  390. resp = query(self.conn_dict[client_addr])
  391. print(resp)
  392. if self.message_queues[s] is not None:
  393. self.message_queues[s].put(bytes(str(resp), encoding = 'utf8'))
  394. if s not in self.outputs:
  395. self.outputs.append(s)
  396. else: # no connection
  397. if self.message_queues[s] is not None:
  398. self.message_queues[s].put(bytes(str("fail"), encoding = 'utf8'))
  399. if s not in self.outputs:
  400. self.outputs.append(s)
  401. elif(data.decode().split(':')[0] == "install"):
  402. client_addr = (data.decode().split(':')[1],int(data.decode().split(':')[2]))
  403. app_name = data.decode().split(':')[3]
  404. app_file = data.decode().split(':')[4]
  405. if client_addr in self.conn_dict.keys():
  406. print('start install application %s to ("%s":"%s")' % (app_name, client_addr[0], client_addr[1]))
  407. res = install(self.conn_dict[client_addr], app_name, app_file)
  408. if self.message_queues[s] is not None:
  409. logging.info("response {} to cmd server".format(res))
  410. self.message_queues[s].put(bytes(res, encoding = 'utf8'))
  411. if s not in self.outputs:
  412. self.outputs.append(s)
  413. elif(data.decode().split(':')[0] == "uninstall"):
  414. client_addr = (data.decode().split(':')[1],int(data.decode().split(':')[2]))
  415. app_name = data.decode().split(':')[3]
  416. if client_addr in self.conn_dict.keys():
  417. print("start uninstall")
  418. res = uninstall(self.conn_dict[client_addr], app_name)
  419. if self.message_queues[s] is not None:
  420. logging.info("response {} to cmd server".format(res))
  421. self.message_queues[s].put(bytes(res, encoding = 'utf8'))
  422. if s not in self.outputs:
  423. self.outputs.append(s)
  424. # if self.message_queues[s] is not None:
  425. # self.message_queues[s].put(data)
  426. # if s not in self.outputs:
  427. # self.outputs.append(s)
  428. else:
  429. logging.warning(data)
  430. # Interpret empty result as closed connection
  431. try:
  432. for dev in self.devices:
  433. if s == dev.conn:
  434. self.devices.remove(dev)
  435. # Stop listening for input on the connection
  436. if s in self.outputs:
  437. self.outputs.remove(s)
  438. self.inputs.remove(s)
  439. # Remove message queue
  440. if s in self.message_queues.keys():
  441. del self.message_queues[s]
  442. s.close()
  443. except OSError as e:
  444. logging.error("OSError raised, unknown connection")
  445. return "got it"
  446. def handler_send(self, writable):
  447. # Handle outputs
  448. for s in writable:
  449. try:
  450. message_queue = self.message_queues.get(s)
  451. send_data = ''
  452. if message_queue is not None:
  453. send_data = message_queue.get_nowait()
  454. except queue.Empty:
  455. self.outputs.remove(s)
  456. else:
  457. # print "sending %s to %s " % (send_data, s.getpeername)
  458. # print "send something"
  459. if message_queue is not None:
  460. s.send(send_data)
  461. else:
  462. print("client has closed")
  463. # del message_queues[s]
  464. # writable.remove(s)
  465. # print "Client %s disconnected" % (client_address)
  466. return "got it"
  467. def handler_exception(self, exceptional):
  468. # # Handle "exceptional conditions"
  469. for s in exceptional:
  470. print('exception condition on', s.getpeername())
  471. # Stop listening for input on the connection
  472. self.inputs.remove(s)
  473. if s in self.outputs:
  474. self.outputs.remove(s)
  475. s.close()
  476. # Remove message queue
  477. del self.message_queues[s]
  478. return "got it"
  479. def event_loop(tcpserver, inputs, outputs):
  480. while inputs:
  481. # Wait for at least one of the sockets to be ready for processing
  482. print('waiting for the next event')
  483. readable, writable, exceptional = select.select(inputs, outputs, inputs)
  484. if readable is not None:
  485. tcp_recever = tcpserver.handler_recever(readable)
  486. if tcp_recever == 'got it':
  487. print("server have received")
  488. if writable is not None:
  489. tcp_send = tcpserver.handler_send(writable)
  490. if tcp_send == 'got it':
  491. print("server have send")
  492. if exceptional is not None:
  493. tcp_exception = tcpserver.handler_exception(exceptional)
  494. if tcp_exception == 'got it':
  495. print("server have exception")
  496. sleep(0.1)
  497. def run_wasm_server():
  498. server_address = ('localhost', 8888)
  499. server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  500. server.setsockopt(socket.SOL_SOCKET,socket.SO_REUSEADDR,1)
  501. inputs = [server]
  502. outputs = []
  503. message_queues = {}
  504. tcpserver = TCPServer(server, server_address, inputs, outputs, message_queues)
  505. task = threading.Thread(target=event_loop,args=(tcpserver,inputs,outputs))
  506. task.start()
  507. if __name__ == '__main__':
  508. logging.basicConfig(level=logging.DEBUG,
  509. filename='wasm_server.log',
  510. filemode='a',
  511. format=
  512. '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
  513. )
  514. server_address = ('0.0.0.0', 8888)
  515. server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  516. server.setsockopt(socket.SOL_SOCKET,socket.SO_REUSEADDR,1)
  517. inputs = [server]
  518. outputs = []
  519. message_queues = {}
  520. tcpserver = TCPServer(server, server_address, inputs, outputs, message_queues)
  521. logging.info("TCP Server start at {}:{}".format(server_address[0], "8888"))
  522. task = threading.Thread(target=event_loop,args=(tcpserver,inputs,outputs))
  523. task.start()
  524. # event_loop(tcpserver, inputs, outputs)