| 1234567891011121314151617181920212223242526272829303132333435363738394041 |
package main

import (
	"bufio"
	"fmt"
	lg "log"
	"math"
	"os"
	"path/filepath"
	"strings"

	// Referenced as gonn in main; the explicit name makes that visible since
	// it differs from the last path element.
	gonn "git.wecise.com/wecise/common/alg/nn"
)
- func main(){
- pwd, err := os.Getwd()
- if err != nil {
- lg.Fatal(err)
- }
- outFile,_ := os.OpenFile(filepath.Join(pwd, "sin.out"), os.O_CREATE | os.O_RDWR, 0777)
- defer outFile.Close()
- nn := gonn.DefaultNetwork(1, 10, 1, true)
- trainInputs := make([][]float64, 100)
- trainTargets := make([][]float64, 100)
- for i := 0; i < len(trainInputs); i++ {
- trainInputs[i] = []float64{float64(i) / 20.0}
- trainTargets[i] = []float64{math.Sin(trainInputs[i][0])}
- }
- trace, err := nn.Train(trainInputs, trainTargets, 1000)
- if err != nil {
- lg.Fatal(err)
- }
- fmt.Print(strings.TrimSpace(trace))
- for i := 0; i < len(trainInputs); i++ {
- x := []float64{float64(i) / 23.0}
- fmt.Fprintln(outFile, x[0], nn.Forward(x)[0])
- }
- //gonn.DumpNN(filepath.Join(pwd, "nn.dump"), nn)
- }
|