From 8df76dbaab2912e86059f8c7d8e4d2abf350a5d3 Mon Sep 17 00:00:00 2001 From: Přemysl Eric Janouch Date: Thu, 18 Jan 2024 09:38:46 +0100 Subject: Make consistent batches a simple edit --- deeptagger/deeptagger.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'deeptagger/deeptagger.cpp') diff --git a/deeptagger/deeptagger.cpp b/deeptagger/deeptagger.cpp index cb28d92..4e8f61e 100644 --- a/deeptagger/deeptagger.cpp +++ b/deeptagger/deeptagger.cpp @@ -255,7 +255,9 @@ static void run(std::vector &images, const Config &config, Ort::Session &session, std::vector shape) { - auto batch = shape[0] = images.size(); + // For consistency, this value may be bumped to always be g.batch, + // but it does not seem to have an effect on anything. + shape[0] = images.size(); Ort::AllocatorWithDefaultOptions allocator; auto tensor = Ort::Value::CreateTensor( @@ -263,7 +265,7 @@ run(std::vector &images, const Config &config, auto input_len = tensor.GetTensorTypeAndShapeInfo().GetElementCount(); auto input_data = tensor.GetTensorMutableData(), pi = input_data; - for (int64_t i = 0; i < batch; i++) { + for (int64_t i = 0; i < images.size(); i++) { switch (config.shape) { case Config::Shape::NCHW: pi = image_to_nchw(pi, images.at(i), config.channels); @@ -296,12 +298,12 @@ run(std::vector &images, const Config &config, auto output_len = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount(); auto output_data = outputs.front().GetTensorData(), po = output_data; - if (output_len != batch * config.tags.size()) { + if (output_len != shape[0] * config.tags.size()) { fprintf(stderr, "Tags don't match the output\n"); return; } - for (size_t i = 0; i < batch; i++) { + for (size_t i = 0; i < images.size(); i++) { for (size_t t = 0; t < config.tags.size(); t++) { float value = *po++; if (config.sigmoid) @@ -616,8 +618,6 @@ infer(Ort::Env &env, const char *path, const std::vector &images) ctx.output_cv.wait(output_lock, [&]{ return ctx.output.size() == g.batch || ctx.done == workers; }); - // It would be possible to add dummy entries to the batch, - // so that the model doesn't need to be rebuilt. if (!ctx.output.empty()) { run(ctx.output, config, session, shape); ctx.output.clear(); -- cgit v1.2.3-70-g09d2