diff options
author | Přemysl Eric Janouch <p@janouch.name> | 2024-01-18 09:38:46 +0100 |
---|---|---|
committer | Přemysl Eric Janouch <p@janouch.name> | 2024-01-18 18:31:10 +0100 |
commit | 8df76dbaab2912e86059f8c7d8e4d2abf350a5d3 (patch) | |
tree | 2d8349237665bd2c5800a55339a6df0e02ea86c3 /deeptagger/deeptagger.cpp | |
parent | 819d2d80e0524dbd9bf1616e0db08818100af7a1 (diff) | |
download | gallery-8df76dbaab2912e86059f8c7d8e4d2abf350a5d3.tar.gz gallery-8df76dbaab2912e86059f8c7d8e4d2abf350a5d3.tar.xz gallery-8df76dbaab2912e86059f8c7d8e4d2abf350a5d3.zip |
Make consistent batches a simple edit
Diffstat (limited to 'deeptagger/deeptagger.cpp')
-rw-r--r-- | deeptagger/deeptagger.cpp | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/deeptagger/deeptagger.cpp b/deeptagger/deeptagger.cpp index cb28d92..4e8f61e 100644 --- a/deeptagger/deeptagger.cpp +++ b/deeptagger/deeptagger.cpp @@ -255,7 +255,9 @@ static void run(std::vector<Magick::Image> &images, const Config &config, Ort::Session &session, std::vector<int64_t> shape) { - auto batch = shape[0] = images.size(); + // For consistency, this value may be bumped to always be g.batch, + // but it does not seem to have an effect on anything. + shape[0] = images.size(); Ort::AllocatorWithDefaultOptions allocator; auto tensor = Ort::Value::CreateTensor<float>( @@ -263,7 +265,7 @@ run(std::vector<Magick::Image> &images, const Config &config, auto input_len = tensor.GetTensorTypeAndShapeInfo().GetElementCount(); auto input_data = tensor.GetTensorMutableData<float>(), pi = input_data; - for (int64_t i = 0; i < batch; i++) { + for (int64_t i = 0; i < images.size(); i++) { switch (config.shape) { case Config::Shape::NCHW: pi = image_to_nchw(pi, images.at(i), config.channels); @@ -296,12 +298,12 @@ run(std::vector<Magick::Image> &images, const Config &config, auto output_len = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount(); auto output_data = outputs.front().GetTensorData<float>(), po = output_data; - if (output_len != batch * config.tags.size()) { + if (output_len != shape[0] * config.tags.size()) { fprintf(stderr, "Tags don't match the output\n"); return; } - for (size_t i = 0; i < batch; i++) { + for (size_t i = 0; i < images.size(); i++) { for (size_t t = 0; t < config.tags.size(); t++) { float value = *po++; if (config.sigmoid) @@ -616,8 +618,6 @@ infer(Ort::Env &env, const char *path, const std::vector<std::string> &images) ctx.output_cv.wait(output_lock, [&]{ return ctx.output.size() == g.batch || ctx.done == workers; }); - // It would be possible to add dummy entries to the batch, - // so that the model doesn't need to be rebuilt. if (!ctx.output.empty()) { run(ctx.output, config, session, shape); ctx.output.clear(); |