// testnn.go (895 B)
  1. package main
  2. import (
  3. "fmt"
  4. "git.wecise.com/wecise/common/alg/nn"
  5. lg "log"
  6. "math"
  7. "os"
  8. "path/filepath"
  9. "strings"
  10. )
  11. func main(){
  12. pwd, err := os.Getwd()
  13. if err != nil {
  14. lg.Fatal(err)
  15. }
  16. outFile,_ := os.OpenFile(filepath.Join(pwd, "sin.out"), os.O_CREATE | os.O_RDWR, 0777)
  17. defer outFile.Close()
  18. nn := gonn.DefaultNetwork(1, 10, 1, true)
  19. trainInputs := make([][]float64, 100)
  20. trainTargets := make([][]float64, 100)
  21. for i := 0; i < len(trainInputs); i++ {
  22. trainInputs[i] = []float64{float64(i) / 20.0}
  23. trainTargets[i] = []float64{math.Sin(trainInputs[i][0])}
  24. }
  25. trace, err := nn.Train(trainInputs, trainTargets, 1000)
  26. if err != nil {
  27. lg.Fatal(err)
  28. }
  29. fmt.Print(strings.TrimSpace(trace))
  30. for i := 0; i < len(trainInputs); i++ {
  31. x := []float64{float64(i) / 23.0}
  32. fmt.Fprintln(outFile, x[0], nn.Forward(x)[0])
  33. }
  34. //gonn.DumpNN(filepath.Join(pwd, "nn.dump"), nn)
  35. }