diff --git a/gpt4all-backend/bert.cpp b/gpt4all-backend/bert.cpp index 31dc7d98..f9e16cb4 100644 --- a/gpt4all-backend/bert.cpp +++ b/gpt4all-backend/bert.cpp @@ -317,7 +317,7 @@ void bert_eval( }; struct ggml_context *ctx0 = ggml_init(params); - struct ggml_cgraph gf = {}; + struct ggml_cgraph *gf = ggml_new_graph(ctx0); // Embeddings. word_embeddings + token_type_embeddings + position_embeddings struct ggml_tensor *token_layer = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); @@ -448,10 +448,10 @@ void bert_eval( ggml_tensor *output = inpL; // run the computation - ggml_build_forward_expand(&gf, output); + ggml_build_forward_expand(gf, output); //ggml_graph_compute_g4a() - ggml_graph_compute_g4a(ctx->work_buf, &gf, n_threads); - //ggml_graph_compute(ctx0, &gf); + ggml_graph_compute_g4a(ctx->work_buf, gf, n_threads); + //ggml_graph_compute(ctx0, gf); // float *dat = ggml_get_data_f32(output); @@ -460,7 +460,7 @@ void bert_eval( #ifdef GGML_PERF // print timing information per ggml operation (for debugging purposes) // requires GGML_PERF to be defined - ggml_graph_print(&gf); + ggml_graph_print(gf); #endif if (!mem_req_mode) { diff --git a/gpt4all-backend/gptj.cpp b/gpt4all-backend/gptj.cpp index 5031ecdf..7825e6bc 100644 --- a/gpt4all-backend/gptj.cpp +++ b/gpt4all-backend/gptj.cpp @@ -343,7 +343,7 @@ bool gptj_eval( }; struct ggml_context * ctx0 = ggml_init(params); - struct ggml_cgraph gf = {}; + struct ggml_cgraph * gf = ggml_new_graph(ctx0); // KQ_pos - contains the positions struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); @@ -395,8 +395,8 @@ bool gptj_eval( ( n_ctx)*ggml_element_size(model.kv_self.v), (il*n_ctx)*ggml_element_size(model.kv_self.v)*n_embd + n_past*ggml_element_size(model.kv_self.v)); - ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k)); - ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); } // Q = 
Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3) @@ -515,22 +515,22 @@ bool gptj_eval( // logits -> probs //inpL = ggml_soft_max(ctx0, inpL); - ggml_build_forward_expand(&gf, inpL); + ggml_build_forward_expand(gf, inpL); // run the computation { std::unique_ptr<uint8_t[]> data; - auto plan = ggml_graph_plan(&gf, n_threads); + auto plan = ggml_graph_plan(gf, n_threads); if (plan.work_size > 0) { data.reset(new uint8_t[plan.work_size]); plan.work_data = data.get(); } - ggml_graph_compute(&gf, &plan); + ggml_graph_compute(gf, &plan); } //if (n_past%100 == 0) { - // ggml_graph_print (&gf); - // ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot"); + // ggml_graph_print (gf); + // ggml_graph_dump_dot(gf, NULL, "gpt-2.dot"); //} //embd_w.resize(n_vocab*N);