aboutsummaryrefslogtreecommitdiff
path: root/deeptagger/deeptagger.cpp
diff options
context:
space:
mode:
authorPřemysl Eric Janouch <p@janouch.name>2024-01-18 09:38:46 +0100
committerPřemysl Eric Janouch <p@janouch.name>2024-01-18 18:31:10 +0100
commit8df76dbaab2912e86059f8c7d8e4d2abf350a5d3 (patch)
tree2d8349237665bd2c5800a55339a6df0e02ea86c3 /deeptagger/deeptagger.cpp
parent819d2d80e0524dbd9bf1616e0db08818100af7a1 (diff)
downloadgallery-8df76dbaab2912e86059f8c7d8e4d2abf350a5d3.tar.gz
gallery-8df76dbaab2912e86059f8c7d8e4d2abf350a5d3.tar.xz
gallery-8df76dbaab2912e86059f8c7d8e4d2abf350a5d3.zip
Make consistent batches a simple edit
Diffstat (limited to 'deeptagger/deeptagger.cpp')
-rw-r--r--deeptagger/deeptagger.cpp12
1 files changed, 6 insertions, 6 deletions
diff --git a/deeptagger/deeptagger.cpp b/deeptagger/deeptagger.cpp
index cb28d92..4e8f61e 100644
--- a/deeptagger/deeptagger.cpp
+++ b/deeptagger/deeptagger.cpp
@@ -255,7 +255,9 @@ static void
run(std::vector<Magick::Image> &images, const Config &config,
Ort::Session &session, std::vector<int64_t> shape)
{
- auto batch = shape[0] = images.size();
+ // For consistency, this value may be bumped to always be g.batch,
+ // but it does not seem to have an effect on anything.
+ shape[0] = images.size();
Ort::AllocatorWithDefaultOptions allocator;
auto tensor = Ort::Value::CreateTensor<float>(
@@ -263,7 +265,7 @@ run(std::vector<Magick::Image> &images, const Config &config,
auto input_len = tensor.GetTensorTypeAndShapeInfo().GetElementCount();
auto input_data = tensor.GetTensorMutableData<float>(), pi = input_data;
- for (int64_t i = 0; i < batch; i++) {
+ for (int64_t i = 0; i < images.size(); i++) {
switch (config.shape) {
case Config::Shape::NCHW:
pi = image_to_nchw(pi, images.at(i), config.channels);
@@ -296,12 +298,12 @@ run(std::vector<Magick::Image> &images, const Config &config,
auto output_len = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount();
auto output_data = outputs.front().GetTensorData<float>(), po = output_data;
- if (output_len != batch * config.tags.size()) {
+ if (output_len != shape[0] * config.tags.size()) {
fprintf(stderr, "Tags don't match the output\n");
return;
}
- for (size_t i = 0; i < batch; i++) {
+ for (size_t i = 0; i < images.size(); i++) {
for (size_t t = 0; t < config.tags.size(); t++) {
float value = *po++;
if (config.sigmoid)
@@ -616,8 +618,6 @@ infer(Ort::Env &env, const char *path, const std::vector<std::string> &images)
ctx.output_cv.wait(output_lock,
[&]{ return ctx.output.size() == g.batch || ctx.done == workers; });
- // It would be possible to add dummy entries to the batch,
- // so that the model doesn't need to be rebuilt.
if (!ctx.output.empty()) {
run(ctx.output, config, session, shape);
ctx.output.clear();