Support.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. import os.path
  2. import itertools
  3. import Tools
  4. import random
  5. import numpy as np
  6. NBTESTSAMPLES = 10
  7. # Nb vectors for barycenter
  8. NBVECTORS = [4,10,16]
  9. VECDIM = [12,14,20]
  10. def genWsum(config,nb):
  11. dims=[]
  12. inputs=[]
  13. weights=[]
  14. output=[]
  15. vecDim = VECDIM[nb % len(VECDIM)]
  16. dims.append(NBTESTSAMPLES)
  17. dims.append(vecDim)
  18. for _ in range(0,NBTESTSAMPLES):
  19. va = np.random.rand(vecDim)
  20. vb = np.random.rand(vecDim)
  21. e = np.sum(va.T * vb) / np.sum(vb)
  22. inputs += list(va)
  23. weights += list(vb)
  24. output.append(e)
  25. inputs=np.array(inputs)
  26. weights=np.array(weights)
  27. output=np.array(output)
  28. config.writeInput(nb, inputs,"Inputs")
  29. config.writeInputS16(nb, dims,"Dims")
  30. config.writeInput(nb, weights,"Weights")
  31. config.writeReference(nb, output,"Ref")
  32. def genBarycenter(config,nb):
  33. dims=[]
  34. inputs=[]
  35. weights=[]
  36. output=[]
  37. vecDim = VECDIM[nb % len(VECDIM)]
  38. nbVecs = NBVECTORS[nb % len(NBVECTORS)]
  39. dims.append(NBTESTSAMPLES)
  40. dims.append(nbVecs)
  41. dims.append(vecDim)
  42. for _ in range(0,NBTESTSAMPLES):
  43. vecs = []
  44. b = np.zeros(vecDim)
  45. coefs = np.random.rand(nbVecs)
  46. for i in range(nbVecs):
  47. va = np.random.rand(vecDim)
  48. b += va * coefs[i]
  49. vecs += list(va)
  50. b = b / np.sum(coefs)
  51. inputs += list(vecs)
  52. weights += list(coefs)
  53. output += list(b)
  54. inputs=np.array(inputs)
  55. weights=np.array(weights)
  56. output=np.array(output)
  57. config.writeInput(nb, inputs,"Inputs")
  58. config.writeInputS16(nb, dims,"Dims")
  59. config.writeInput(nb, weights,"Weights")
  60. config.writeReference(nb, output,"Ref")
  61. def writeTests(config):
  62. genBarycenter(config,1)
  63. genWsum(config,2)
  64. PATTERNDIR = os.path.join("Patterns","DSP","Support","Support")
  65. PARAMDIR = os.path.join("Parameters","DSP","Support","Support")
  66. configf32=Tools.Config(PATTERNDIR,PARAMDIR,"f32")
  67. writeTests(configf32)