From 44738ce18c2fe79b03707ef73f2697e109d31edf Mon Sep 17 00:00:00 2001 From: Yen-Hsiang Chang Date: Sun, 7 Jun 2020 14:03:27 -0500 Subject: [PATCH 1/4] add autograd and cupti --- Gopkg.lock | 9 ++--- predictor.go | 71 ++++++++++++++++++++++++++++++++++++++- trace.go | 95 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 170 insertions(+), 5 deletions(-) create mode 100644 trace.go diff --git a/Gopkg.lock b/Gopkg.lock index 52d08c2..59e78e2 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -705,7 +705,7 @@ [[projects]] branch = "master" - digest = "1:a6cb3e904be38e622c5c1fb9271018c37f15ab09bdf9becb1fe0d4f5377afd85" + digest = "0:" name = "github.com/rai-project/dlframework" packages = [ ".", @@ -713,7 +713,7 @@ "framework/options", ] pruneopts = "T" - revision = "d642f9c121d88be304eb85e05e53aca3428d763b" + revision = "4aaa12cfa5874a785cbe70a2bc8a849472c7bc95" [[projects]] branch = "master" @@ -784,7 +784,7 @@ [[projects]] branch = "master" - digest = "1:ffc0e7b5e9b5314623422c9b6e46542bd8e8866f749bd5dbe825db4867822584" + digest = "0:" name = "github.com/rai-project/tracer" packages = [ ".", @@ -797,7 +797,7 @@ "zipkin", ] pruneopts = "T" - revision = "c5fab9c8969a5dc405c2d79caa909ac7501dbf27" + revision = "97d6a9677dc3a0e73fc9df74d366323884c19190" [[projects]] branch = "master" @@ -1437,6 +1437,7 @@ "github.com/anthonynsimon/bild/transform", "github.com/benesch/cgosymbolizer", "github.com/k0kubun/pp", + "github.com/opentracing/opentracing-go", "github.com/pkg/errors", "github.com/rai-project/config", "github.com/rai-project/dlframework", diff --git a/predictor.go b/predictor.go index 158d456..b020735 100644 --- a/predictor.go +++ b/predictor.go @@ -10,11 +10,15 @@ import ( "context" "fmt" "runtime" + "strings" + "time" "unsafe" "github.com/Unknwon/com" + "github.com/k0kubun/pp" "github.com/pkg/errors" "github.com/rai-project/dlframework/framework/options" + cupti "github.com/rai-project/go-cupti" nvidiasmi "github.com/rai-project/nvidia-smi" "github.com/rai-project/tracer" "gorgonia.org/tensor" @@ -24,6 +28,7 @@ type Predictor struct { ctx C.Torch_PredictorContext inputs []C.Torch_TensorContext options *options.Options + cu *cupti.CUPTI } func New(ctx context.Context, opts ...options.Option) (*Predictor, error) { @@ -86,11 +91,45 @@ func (p *Predictor) Predict(ctx context.Context, inputs []tensor.Tensor) error { inputSlice[ii] = toTensorCtx(dense, fromDevice(p.options)) } - predictSpan, _ := tracer.StartSpanFromContext(ctx, tracer.MODEL_TRACE, "c_predict") + predictSpan, ctx := tracer.StartSpanFromContext(ctx, tracer.MODEL_TRACE, "c_predict") defer predictSpan.Finish() + if p.options.TraceLevel() >= tracer.FRAMEWORK_TRACE { + p.EnableProfiling() + start_time := time.Now().UnixNano() + err := p.StartProfiling("pytorch", "predict") + if err != nil { + log.WithError(err).WithField("framework", "pytorch").Error("unable to start framework profiling") + } else { + defer func() { + p.EndProfiling() + end_time := time.Now().UnixNano() + + profBuffer, err := p.ReadProfile() + if err != nil { + pp.Println(err) + return + } + t, err := NewTrace(profBuffer, start_time, end_time) + if err != nil { + panic(err) + return + } + t.Publish(ctx, tracer.FRAMEWORK_TRACE) + p.DisableProfiling() + }() + } + } + + err := p.cuptiStart(ctx) + if err != nil { + return err + } + C.Torch_PredictorRun(p.ctx, &inputSlice[0], C.int(inputsLength)) + p.cuptiClose() + return GetError() } @@ -135,6 +174,36 @@ func (p *Predictor) Close() { p.finalize() } +func (p *Predictor) cuptiStart(ctx context.Context) error { + if p.options.TraceLevel() < tracer.SYSTEM_LIBRARY_TRACE { + return nil + } + metrics := []string{} + if p.options.GPUMetrics() != "" { + metrics = strings.Split(p.options.GPUMetrics(), ",") + } + + cu, err := cupti.New(cupti.Context(ctx), + cupti.SamplingPeriod(0), + cupti.Metrics(metrics), + ) + if err != nil { + return err + } + + p.cu = cu + return nil +} + +func (p *Predictor) cuptiClose() { + if p.cu == nil { + return + } + p.cu.Wait() + p.cu.Close() + p.cu = nil +} + func init() { C.InitPytorch() } diff --git a/trace.go b/trace.go new file mode 100644 index 0000000..05e5b4b --- /dev/null +++ b/trace.go @@ -0,0 +1,95 @@ +package pytorch + +import ( + "context" + "encoding/json" + "fmt" + "time" + + opentracing "github.com/opentracing/opentracing-go" + "github.com/rai-project/tracer" +) + +type TraceEvent struct { + Name string `json:"name,omitempty"` + Phase string `json:"ph,omitempty"` + Timestamp float32 `json:"ts,omitempty"` + Duration float32 `json:"dur,omitempty"` + ProcessID string `json:"pid,omitempty"` + ThreadID int64 `json:"tid,omitempty"` + Start int64 `json:"-"` + End int64 `json:"-"` + StartTime time.Time `json:"-"` + EndTime time.Time `json:"-"` +} + +func (t TraceEvent) ID() string { + return fmt.Sprintf("%s/%v", t.Name, t.ThreadID) +} + +type TraceEvents []TraceEvent + +func (t TraceEvents) Len() int { return len(t) } +func (t TraceEvents) Swap(i, j int) { t[i], t[j] = t[j], t[i] } +func (t TraceEvents) Less(i, j int) bool { return t[i].Start < t[j].Start } + +type Trace struct { + StartTime time.Time + EndTime time.Time + TraceEvents TraceEvents +} + +func (t Trace) Len() int { return t.TraceEvents.Len() } +func (t Trace) Swap(i, j int) { t.TraceEvents.Swap(i, j) } +func (t Trace) Less(i, j int) bool { return t.TraceEvents.Less(i, j) } + +func NewTrace(data string, start_time int64, end_time int64) (*Trace, error) { + trace := new(Trace) + err := json.Unmarshal([]byte(data), &trace.TraceEvents) + if err != nil { + return nil, err + } + trace.StartTime = time.Unix(0, start_time) + trace.EndTime = time.Unix(0, end_time) + for ii, event := range trace.TraceEvents { + trace.TraceEvents[ii].Start = start_time + int64(event.Timestamp * 1000) + trace.TraceEvents[ii].StartTime = time.Unix(0, trace.TraceEvents[ii].Start) + trace.TraceEvents[ii].End = start_time + int64(event.Timestamp * 1000 + event.Duration * 1000) + trace.TraceEvents[ii].EndTime = time.Unix(0, trace.TraceEvents[ii].End) + } + return trace, nil +} + +func (event *TraceEvent) Publish(ctx context.Context, lvl tracer.Level, opts ...opentracing.StartSpanOption) error { + tags := opentracing.Tags{ + "phase": event.Phase, + "process_id": event.ProcessID, + "thread_id": event.ThreadID, + } + s, _ := tracer.StartSpanFromContext( + ctx, + lvl, + event.Name, + opentracing.StartTime(event.StartTime), + tags, + ) + if s == nil { + log.WithField("event_name", event.Name). + WithField("tags", tags). + Error("failed to create span from context") + return nil + } + s.FinishWithOptions(opentracing.FinishOptions{ + FinishTime: event.EndTime, + }) + return nil +} + +func (t *Trace) Publish(ctx context.Context, lvl tracer.Level, opts ...opentracing.StartSpanOption) error { + for _, event := range t.TraceEvents { + if err := event.Publish(ctx, lvl, opts...); err != nil { + return err + } + } + return nil +} From f3ab434d6d34550731fd5526d5fd738c1c2c849b Mon Sep 17 00:00:00 2001 From: Yen-Hsiang Chang Date: Tue, 9 Jun 2020 21:02:02 -0500 Subject: [PATCH 2/4] move the timestamp to the time after constructing profiler --- cbits/predictor.hpp | 2 ++ predictor.cpp | 15 +++++++++++++++ predictor.go | 7 +++---- trace.go | 8 +++----- 4 files changed, 23 insertions(+), 9 deletions(-) diff --git a/cbits/predictor.hpp b/cbits/predictor.hpp index 1ed9323..0a4f86e 100644 --- a/cbits/predictor.hpp +++ b/cbits/predictor.hpp @@ -137,6 +137,8 @@ void Torch_ProfilingDisable(Torch_PredictorContext pred); char* Torch_ProfilingRead(Torch_PredictorContext pred); +int64_t Torch_ProfilingGetStartTime(Torch_PredictorContext pred); + // JIT #ifdef ENABLE_PYTROCH_JIT Torch_JITModuleContext Torch_CompileTorchScript(char* script, Torch_Error* error); diff --git a/predictor.cpp b/predictor.cpp index 88a77ef..3d14015 100644 --- a/predictor.cpp +++ b/predictor.cpp @@ -5,6 +5,7 @@ #include "timer.impl.hpp" #include +#include #include #include #include @@ -35,6 +36,7 @@ class Predictor { profile *prof_{nullptr}; std::string profile_filename_{"profile.trace"}; bool profile_enabled_{false}; + int64_t profile_start; }; Predictor::Predictor(const string &model_file, Torch_DeviceKind device) { @@ -66,6 +68,7 @@ void Predictor::Predict(Torch_TensorContext *cInputs, int inputLength) { if (profile_enabled_ == true) { autograd::profiler::RecordProfile guard(profile_filename_); + profile_start = static_cast(std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count()); output_ = net_.forward(inputs); return; } @@ -200,3 +203,15 @@ char *Torch_ProfilingRead(Torch_PredictorContext pred) { END_HANDLE_TH_ERRORS(Torch_GlobalError, (char *)0); } + +int64_t Torch_ProfilingGetStartTime(Torch_PredictorContext pred) { + HANDLE_TH_ERRORS(Torch_GlobalError); + auto predictor = (Predictor *)pred; + if (predictor == nullptr) { + return 0; + } + + return predictor->profile_start; + END_HANDLE_TH_ERRORS(Torch_GlobalError, 0); +} + diff --git a/predictor.go b/predictor.go index b020735..7e5b012 100644 --- a/predictor.go +++ b/predictor.go @@ -11,7 +11,6 @@ import ( "fmt" "runtime" "strings" - "time" "unsafe" "github.com/Unknwon/com" @@ -96,21 +95,21 @@ func (p *Predictor) Predict(ctx context.Context, inputs []tensor.Tensor) error { if p.options.TraceLevel() >= tracer.FRAMEWORK_TRACE { p.EnableProfiling() - start_time := time.Now().UnixNano() err := p.StartProfiling("pytorch", "predict") if err != nil { log.WithError(err).WithField("framework", "pytorch").Error("unable to start framework profiling") } else { defer func() { p.EndProfiling() - end_time := time.Now().UnixNano() + + start_time := int64(C.Torch_ProfilingGetStartTime(p.ctx)) profBuffer, err := p.ReadProfile() if err != nil { pp.Println(err) return } - t, err := NewTrace(profBuffer, start_time, end_time) + t, err := NewTrace(profBuffer, start_time) if err != nil { panic(err) return diff --git a/trace.go b/trace.go index 05e5b4b..16fcef5 100644 --- a/trace.go +++ b/trace.go @@ -35,7 +35,6 @@ func (t TraceEvents) Less(i, j int) bool { return t[i].Start < t[j].Start } type Trace struct { StartTime time.Time - EndTime time.Time TraceEvents TraceEvents } @@ -43,18 +42,17 @@ func (t Trace) Len() int { return t.TraceEvents.Len() } func (t Trace) Swap(i, j int) { t.TraceEvents.Swap(i, j) } func (t Trace) Less(i, j int) bool { return t.TraceEvents.Less(i, j) } -func NewTrace(data string, start_time int64, end_time int64) (*Trace, error) { +func NewTrace(data string, start_time int64) (*Trace, error) { trace := new(Trace) err := json.Unmarshal([]byte(data), &trace.TraceEvents) if err != nil { return nil, err } trace.StartTime = time.Unix(0, start_time) - trace.EndTime = time.Unix(0, end_time) for ii, event := range trace.TraceEvents { - trace.TraceEvents[ii].Start = start_time + int64(event.Timestamp * 1000) + trace.TraceEvents[ii].Start = start_time + int64(event.Timestamp*1000) trace.TraceEvents[ii].StartTime = time.Unix(0, trace.TraceEvents[ii].Start) - trace.TraceEvents[ii].End = start_time + int64(event.Timestamp * 1000 + event.Duration * 1000) + trace.TraceEvents[ii].End = start_time + int64(event.Timestamp*1000+event.Duration*1000) trace.TraceEvents[ii].EndTime = time.Unix(0, trace.TraceEvents[ii].End) } return trace, nil From ff5d61f6a8a8209d52a0ecf4e9254e25eb2727d9 Mon Sep 17 00:00:00 2001 From: Yen-Hsiang Chang Date: Sun, 28 Jun 2020 16:44:58 -0500 Subject: [PATCH 3/4] fix typo --- predictor.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/predictor.cpp b/predictor.cpp index 3d14015..ea94376 100644 --- a/predictor.cpp +++ b/predictor.cpp @@ -68,7 +68,7 @@ void Predictor::Predict(Torch_TensorContext *cInputs, int inputLength) { if (profile_enabled_ == true) { autograd::profiler::RecordProfile guard(profile_filename_); - profile_start = static_cast(std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count()); + profile_start = static_cast(std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count()); output_ = net_.forward(inputs); return; } From 9384fbd9f4e2778e9ecf00d428be05e84c0f2f1b Mon Sep 17 00:00:00 2001 From: Yen-Hsiang Chang Date: Sat, 22 Aug 2020 13:00:20 -0500 Subject: [PATCH 4/4] fix some of the memory leak --- Gopkg.lock | 237 ++++---------------------------------------------- Gopkg.toml | 8 ++ predictor.cpp | 9 +- predictor.go | 9 +- 4 files changed, 34 insertions(+), 229 deletions(-) diff --git a/Gopkg.lock b/Gopkg.lock index 59e78e2..822a2f2 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -42,14 +42,6 @@ revision = "805c489aa98f412e79eb308a37996bf9d8b1c91e" version = "v1.5.0" -[[projects]] - digest = "1:56b22d8995bee726d179c35b1693da4627ba6d755627f82119197e56cb92c3ed" - name = "github.com/Shopify/sarama" - packages = ["."] - pruneopts = "T" - revision = "1358e9c6e61694cd61b2daae79f5aa4b8073c976" - version = "v1.24.0" - [[projects]] digest = "1:e92f5581902c345eb4ceffdcd4a854fb8f73cf436d47d837d1ec98ef1fe0a214" name = "github.com/StackExchange/wmi" @@ -80,14 +72,6 @@ revision = "35fd1904d43f809ded0959922164885e25577aa2" version = "v0.10.0" -[[projects]] - digest = "1:e4219cafc90c03296a8a144fc58f8334d56c1ce18258b6634cd547f4eeff6862" - name = "github.com/apache/thrift" - packages = ["lib/go/thrift"] - pruneopts = "T" - revision = "cecee50308fc7e6f77f55b3fd906c1c6c471fa2f" - version = "v0.13.0" - [[projects]] branch = "master" digest = "1:1237f2e0fe3256bbce813cc1bafddaf26b85bddbd2f8c08c30eb845140198aee" @@ -206,38 +190,6 @@ revision = "8991bc29aa16c548c550c7ff78260e27b9ab7c73" version = "v1.1.1" -[[projects]] - digest = "1:f176f62e48b0b01d0c6c424176fb258b382c4651050de0e0b37b29dda56d9e0c" - name = "github.com/dustin/go-humanize" - packages = ["."] - pruneopts = "T" - revision = "9f541cc9db5d55bce703bd99987c9d5cb8eea45e" - version = "v1.0.0" - -[[projects]] - digest = "1:8480f60ae4bd46d06511a2577565acd0e60110899240c7e54be1e4b0bd55ed06" - name = "github.com/eapache/go-resiliency" - packages = ["breaker"] - pruneopts = "T" - revision = "5efd2ed019fd331ec2defc6f3bd98882f1e3e636" - version = "v1.2.0" - -[[projects]] - branch = "master" - digest = "1:5b9eecdd952ee665ce94669cd1074c4fe5fb389a1095a42133ed0ad7ce1e81bd" - name = "github.com/eapache/go-xerial-snappy" - packages = ["."] - pruneopts = "T" - revision = "776d5712da21bc4762676d614db1d8a64f4238b0" - -[[projects]] - digest = "1:444b82bfe35c83bbcaf84e310fb81a1f9ece03edfed586483c869e2c046aef69" - name = "github.com/eapache/queue" - packages = ["."] - pruneopts = "T" - revision = "44cc805cf13205b55f69e14bcb69867d1ae92f98" - version = "v1.1.0" - [[projects]] branch = "master" digest = "1:9a023ac8c30a6871d3fc94c9c0b5a0639ede096c351fd6cdcd7268fe9cb95af2" @@ -270,14 +222,6 @@ revision = "c2828203cd70a50dcccfb2761f8b1f8ceef9a8e9" version = "v1.4.7" -[[projects]] - digest = "1:4062bc6de62d73e2be342243cf138cf499b34d558876db8d9430e2149388a4d8" - name = "github.com/go-logfmt/logfmt" - packages = ["."] - pruneopts = "T" - revision = "07c9b44f60d7ffdfb7d8efe1ad539965737836dc" - version = "v0.4.0" - [[projects]] digest = "1:721506c3d88697ba8229dec8530cdc02b3efdb1a5c12e5de7e7e70cb41ba3a4e" name = "github.com/go-ole/go-ole" @@ -335,14 +279,6 @@ pruneopts = "T" revision = "ed6926b37a637426117ccab59282c3839528a700" -[[projects]] - digest = "1:f37c069fadbaa889f79aea9cd463d225000a0d26f1e7423f14c0cd58fbc5c597" - name = "github.com/golang/snappy" - packages = ["."] - pruneopts = "T" - revision = "2a8bb927dd31d8daada140a5d09578521ce5c36a" - version = "v0.0.1" - [[projects]] digest = "1:810db00a0be338cd083e07fbae010c118fc94dfe4db53b4f05b9918fb3afcb99" name = "github.com/google/flatbuffers" @@ -398,14 +334,6 @@ revision = "c9a55de4fe06c920a71964b53cfe3dd293a3c743" version = "v1.0.0" -[[projects]] - digest = "1:f14364057165381ea296e49f8870a9ffce2b8a95e34d6ae06c759106aaef428c" - name = "github.com/hashicorp/go-uuid" - packages = ["."] - pruneopts = "T" - revision = "4f571afc59f3043a65f8fe6bf46d887b10a01d43" - version = "v1.0.1" - [[projects]] digest = "1:88e0b0baeb9072f0a4afbcf12dda615fc8be001d1802357538591155998da21b" name = "github.com/hashicorp/go-version" @@ -433,6 +361,14 @@ revision = "8cb6e5b959231cc1119e43259c4a608f9c51a241" version = "v1.0.0" +[[projects]] + digest = "0:" + name = "github.com/iancoleman/strcase" + packages = ["."] + pruneopts = "T" + revision = "23e9d4e5c09d4767bb6a4d9fdc49a9d548b70898" + version = "v0.1.0" + [[projects]] branch = "master" digest = "1:e345ab0697f8f63d0ff3cc4c4c90fa470fa79c8d3c0b461a1d16df2f5b0c1fd1" @@ -449,17 +385,6 @@ pruneopts = "T" revision = "5e5cf60278f657d30daa329dd0e7e893b6b8f027" -[[projects]] - digest = "1:569e67119c5cf9c8fc8abc8c763db09df12530cdadd87f93d33e3f7141835810" - name = "github.com/jcmturner/gofork" - packages = [ - "encoding/asn1", - "x/crypto/pbkdf2", - ] - pruneopts = "T" - revision = "dc7c13fece037a4a36e2b3c69db4991498d30692" - version = "v1.0.0" - [[projects]] digest = "1:b3f2dea8fe8eadb5833f7c8a2ef35dae07308bf7a4d0fee7f37774314e1964d3" name = "github.com/jmespath/go-jmespath" @@ -495,20 +420,6 @@ revision = "3d73dea227e0711e38b911ffa6fbafc8ff6b2991" version = "v3.0.1" -[[projects]] - digest = "1:c86485dd30381468d727b19c38674f84da2f734575f0e05dee020bc5db6d2997" - name = "github.com/klauspost/compress" - packages = [ - "fse", - "huff0", - "snappy", - "zstd", - "zstd/internal/xxhash", - ] - pruneopts = "T" - revision = "a41f1a10bd3b167958ff6df80b800fe2969dd3ca" - version = "v1.9.0" - [[projects]] digest = "1:31e761d97c76151dde79e9d28964a812c46efc5baee4085b86f68f0c654450de" name = "github.com/konsorten/go-windows-terminal-sequences" @@ -517,14 +428,6 @@ revision = "f55edac94c9bbba5d6182a4be46d86a2c9b5b50e" version = "v1.0.2" -[[projects]] - branch = "master" - digest = "1:a64e323dc06b73892e5bb5d040ced475c4645d456038333883f58934abbf6f72" - name = "github.com/kr/logfmt" - packages = ["."] - pruneopts = "T" - revision = "b84e30acd515aadc4b783ad4ff83aff3299bdfe0" - [[projects]] digest = "1:d7cc16f6f66fd3f5864ff77480288704b02e5263f6f243dae62b43fdf4bb638e" name = "github.com/magiconair/properties" @@ -638,28 +541,6 @@ revision = "659c90643e714681897ec2521c60567dd21da733" version = "v1.1.0" -[[projects]] - digest = "1:d7d9781d7c86e16db72a1c57661cb4ed66530ae32d1e36133ed94fea42e22d3a" - name = "github.com/openzipkin-contrib/zipkin-go-opentracing" - packages = [ - "flag", - "thrift/gen-go/scribe", - "thrift/gen-go/zipkincore", - "types", - "wire", - ] - pruneopts = "T" - revision = "f0f479ad013a498e4cbfb369414e5d3880903779" - version = "v0.3.5" - -[[projects]] - digest = "1:d7d9781d7c86e16db72a1c57661cb4ed66530ae32d1e36133ed94fea42e22d3a" - name = "github.com/openzipkin/zipkin-go-opentracing" - packages = ["."] - pruneopts = "T" - revision = "f0f479ad013a498e4cbfb369414e5d3880903779" - version = "v0.3.5" - [[projects]] digest = "1:808cdddf087fb64baeae67b8dfaee2069034d9704923a3cb8bd96a995421a625" name = "github.com/patrickmn/go-cache" @@ -676,17 +557,6 @@ revision = "8fe62057ea2d46ce44254c98e84e810044dbe197" version = "v1.5.0" -[[projects]] - digest = "1:8cf21019a14a486cab3b0fd97aa78cdf58bc7efcaf7dc9d1b4b8cb76fefa17d6" - name = "github.com/pierrec/lz4" - packages = [ - ".", - "internal/xxh32", - ] - pruneopts = "T" - revision = "645f9b948eee34cbcc335c70999f79c29c420fbf" - version = "v2.3.0" - [[projects]] digest = "1:cf31692c14422fa27c83a05292eb5cbe0fb2775972e8f1f8446a71549bd8980b" name = "github.com/pkg/errors" @@ -724,15 +594,15 @@ revision = "ba1da121542db77fc12e929de694d3defb5b44f7" [[projects]] - branch = "develop" - digest = "1:0ee5511bee3d685f5e7099fb101a0bc1673a235037cbb1b5ec0fd94eff4d1632" + branch = "master" + digest = "0:" name = "github.com/rai-project/go-cupti" packages = [ ".", "types", ] pruneopts = "T" - revision = "70df194a40553b7fd0ab52b2704b39844cfe8b9a" + revision = "042c8db6b58eea64b98a52e9dc4cc069196abc84" [[projects]] branch = "master" @@ -794,7 +664,6 @@ "noop", "observer", "utils", - "zipkin", ] pruneopts = "T" revision = "97d6a9677dc3a0e73fc9df74d366323884c19190" @@ -831,14 +700,6 @@ pruneopts = "T" revision = "1b01514224a1a60a6bcbb0b6b9d3a00ec14ae17f" -[[projects]] - branch = "master" - digest = "1:61c70a37c48b3085095e21bc86138f58d29cc125a73ad8b54753923a9a3c2043" - name = "github.com/rcrowley/go-metrics" - packages = ["."] - pruneopts = "T" - revision = "cac0b30c2563378d434b5af411844adff8e32960" - [[projects]] digest = "1:ee4f4f6c0f23c0e4ae7bd51b4c146272ea31fd981efff8d8a83e8b30d74b136c" name = "github.com/shirou/gopsutil" @@ -1017,18 +878,16 @@ [[projects]] branch = "master" - digest = "1:054a206d6d76245f60e2d1823b8ca9b205352cff41b7932b98d3416731a3b368" + digest = "0:" name = "golang.org/x/crypto" packages = [ "cast5", - "md4", "openpgp", "openpgp/armor", "openpgp/elgamal", "openpgp/errors", "openpgp/packet", "openpgp/s2k", - "pbkdf2", ] pruneopts = "T" revision = "87dc89f01550277dc22b74ffcf4cd89fa2f40f4c" @@ -1057,7 +916,7 @@ [[projects]] branch = "master" - digest = "1:b9c7a57626be9ad76a41d7e42a7457638bd0bb62959fe97eb61b2e954bc42cbb" + digest = "0:" name = "golang.org/x/net" packages = [ "context", @@ -1066,10 +925,10 @@ "http2", "http2/hpack", "idna", - "internal/socks", "internal/timeseries", - "proxy", "trace", + "webdav", + "webdav/internal/xml", ] pruneopts = "T" revision = "ec77196f6094c3492a8b61f2c11cf937f78992ae" @@ -1279,72 +1138,6 @@ revision = "f6d0f9ee430895e87ef1ceb5ac8f39725bafceef" version = "v1.24.0" -[[projects]] - digest = "1:c902038ee2d6f964d3b9f2c718126571410c5d81251cbab9fe58abd37803513c" - name = "gopkg.in/jcmturner/aescts.v1" - packages = ["."] - pruneopts = "T" - revision = "f6abebb3171c4c1b1fea279cb7c7325020a26290" - version = "v1.0.1" - -[[projects]] - digest = "1:a1a3e185c03d79a7452d5d5b4c91be4cc433f55e6ed3a35233d852c966e39013" - name = "gopkg.in/jcmturner/dnsutils.v1" - packages = ["."] - pruneopts = "T" - revision = "13eeb8d49ffb74d7a75784c35e4d900607a3943c" - version = "v1.0.1" - -[[projects]] - digest = "1:3fc73aba04c19a6ff67e1bb055ec501dba6ec87b9f29062e26e19c8a863f0eba" - name = "gopkg.in/jcmturner/gokrb5.v7" - packages = [ - "asn1tools", - "client", - "config", - "credentials", - "crypto", - "crypto/common", - "crypto/etype", - "crypto/rfc3961", - "crypto/rfc3962", - "crypto/rfc4757", - "crypto/rfc8009", - "gssapi", - "iana", - "iana/addrtype", - "iana/adtype", - "iana/asnAppTag", - "iana/chksumtype", - "iana/errorcode", - "iana/etypeID", - "iana/flags", - "iana/keyusage", - "iana/msgtype", - "iana/nametype", - "iana/patype", - "kadmin", - "keytab", - "krberror", - "messages", - "pac", - "types", - ] - pruneopts = "T" - revision = "363118e62befa8a14ff01031c025026077fe5d6d" - version = "v7.3.0" - -[[projects]] - digest = "1:ab34660806b7f5fff3f808f7c2e3a967474bbd3cde6a16885fe2ee8f47b2b255" - name = "gopkg.in/jcmturner/rpc.v1" - packages = [ - "mstypes", - "ndr", - ] - pruneopts = "T" - revision = "99a8ce2fbf8b8087b6ed12a37c61b10f04070043" - version = "v1.1.0" - [[projects]] branch = "v2" digest = "1:fe8f3f85c6b45a792108158fac787bc64aa433ec2b23800e9222a8a7836a66de" diff --git a/Gopkg.toml b/Gopkg.toml index 9a96737..1dee1b8 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -52,6 +52,10 @@ branch = "master" name = "github.com/rai-project/downloadmanager" +[[constraint]] + branch = "master" + name = "github.com/rai-project/go-cupti" + [[constraint]] branch = "master" name = "github.com/rai-project/logger" @@ -72,6 +76,10 @@ name = "gorgonia.org/tensor" version = "0.9.0-beta" +[[override]] + name = "github.com/k0kubun/pp" + version = "3.0.1" + [prune] go-tests = true diff --git a/predictor.cpp b/predictor.cpp index ea94376..4ba81b1 100644 --- a/predictor.cpp +++ b/predictor.cpp @@ -36,7 +36,7 @@ class Predictor { profile *prof_{nullptr}; std::string profile_filename_{"profile.trace"}; bool profile_enabled_{false}; - int64_t profile_start; + int64_t profile_start_; }; Predictor::Predictor(const string &model_file, Torch_DeviceKind device) { @@ -58,17 +58,20 @@ void Predictor::Predict(Torch_TensorContext *cInputs, int inputLength) { for (int ii = 0; ii < inputLength; ii++) { at::Tensor tensor = reinterpret_cast(cInputs[ii])->tensor; +#ifdef DEBUG std::cout << "tensor dim = " << tensor.dim() << " size = "; for (auto sz : tensor.sizes()) { std::cout << sz << ", "; } std::cout << "\n"; +#endif + inputs.emplace_back(tensor); } if (profile_enabled_ == true) { autograd::profiler::RecordProfile guard(profile_filename_); - profile_start = static_cast(std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count()); + profile_start_ = static_cast(std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count()); output_ = net_.forward(inputs); return; } @@ -211,7 +214,7 @@ int64_t Torch_ProfilingGetStartTime(Torch_PredictorContext pred) { return 0; } - return predictor->profile_start; + return predictor->profile_start_; END_HANDLE_TH_ERRORS(Torch_GlobalError, 0); } diff --git a/predictor.go b/predictor.go index 7e5b012..52adecf 100644 --- a/predictor.go +++ b/predictor.go @@ -25,7 +25,6 @@ import ( type Predictor struct { ctx C.Torch_PredictorContext - inputs []C.Torch_TensorContext options *options.Options cu *cupti.CUPTI } @@ -89,6 +88,11 @@ func (p *Predictor) Predict(ctx context.Context, inputs []tensor.Tensor) error { } inputSlice[ii] = toTensorCtx(dense, fromDevice(p.options)) } + defer func() { + for _, input := range inputSlice { + C.Torch_DeleteTensor(input) + } + }() predictSpan, ctx := tracer.StartSpanFromContext(ctx, tracer.MODEL_TRACE, "c_predict") defer predictSpan.Finish() @@ -160,9 +164,6 @@ func (p *Predictor) finalize() { if p == nil { return } - for _, input := range p.inputs { - C.Torch_DeleteTensor(input) - } if p.ctx != nil { C.Torch_PredictorDelete(p.ctx) }