diff options
Diffstat (limited to 'deeptagger')
| -rw-r--r-- | deeptagger/deeptagger.cpp | 12 | 
1 files changed, 6 insertions, 6 deletions
| diff --git a/deeptagger/deeptagger.cpp b/deeptagger/deeptagger.cpp index cb28d92..4e8f61e 100644 --- a/deeptagger/deeptagger.cpp +++ b/deeptagger/deeptagger.cpp @@ -255,7 +255,9 @@ static void  run(std::vector<Magick::Image> &images, const Config &config,  	Ort::Session &session, std::vector<int64_t> shape)  { -	auto batch = shape[0] = images.size(); +	// For consistency, this value may be bumped to always be g.batch, +	// but it does not seem to have an effect on anything. +	shape[0] = images.size();  	Ort::AllocatorWithDefaultOptions allocator;  	auto tensor = Ort::Value::CreateTensor<float>( @@ -263,7 +265,7 @@ run(std::vector<Magick::Image> &images, const Config &config,  	auto input_len = tensor.GetTensorTypeAndShapeInfo().GetElementCount();  	auto input_data = tensor.GetTensorMutableData<float>(), pi = input_data; -	for (int64_t i = 0; i < batch; i++) { +	for (int64_t i = 0; i < images.size(); i++) {  		switch (config.shape) {  		case Config::Shape::NCHW:  			pi = image_to_nchw(pi, images.at(i), config.channels); @@ -296,12 +298,12 @@ run(std::vector<Magick::Image> &images, const Config &config,  	auto output_len = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount();  	auto output_data = outputs.front().GetTensorData<float>(), po = output_data; -	if (output_len != batch * config.tags.size()) { +	if (output_len != shape[0] * config.tags.size()) {  		fprintf(stderr, "Tags don't match the output\n");  		return;  	} -	for (size_t i = 0; i < batch; i++) { +	for (size_t i = 0; i < images.size(); i++) {  		for (size_t t = 0; t < config.tags.size(); t++) {  			float value = *po++;  			if (config.sigmoid) @@ -616,8 +618,6 @@ infer(Ort::Env &env, const char *path, const std::vector<std::string> &images)  		ctx.output_cv.wait(output_lock,  			[&]{ return ctx.output.size() == g.batch || ctx.done == workers; }); -		// It would be possible to add dummy entries to the batch, -		// so that the model doesn't need to be rebuilt.  		if (!ctx.output.empty()) {  			run(ctx.output, config, session, shape);  			ctx.output.clear(); | 
