How can I feed a sparse placeholder in a TensorFlow model from Java

Tobias Ott :

I'm trying to calculate the best match for a given address with the kNN algorithm in TensorFlow, which works pretty well, but when I try to export the model and use it in our Java environment I get stuck on how to feed the sparse placeholders from Java.

Here is a pretty much stripped-down version of the Python part, which returns the smallest distance between the test name and the best reference name. So far this works as expected. When I export the model and import it in my Java program it always returns the same value (the distance of the placeholders' defaults). I assume that the Python function sparse_from_word_vec(word_vec) isn't in the model, which would totally make sense to me, but then how should I make this sparse tensor? My input is a single string and I need to create a fitting sparse tensor (value) to calculate the distance. I also searched for a way to generate the sparse tensor on the Java side, but without success.

import tensorflow as tf
import pandas as pd

d = {'NAME': ['max mustermann', 
              'erika musterfrau', 
              'joseph haydn', 
              'johann sebastian bach', 
              'wolfgang amadeus mozart']}

df = pd.DataFrame(data=d)  

input_name = tf.placeholder_with_default('max musterman',(), name='input_name')
output_dist = tf.placeholder(tf.float32, (), name='output_dist')

test_name = tf.sparse_placeholder(dtype=tf.string)
ref_names = tf.sparse_placeholder(dtype=tf.string)

output_dist = tf.edit_distance(test_name, ref_names, normalize=True)

def sparse_from_word_vec(word_vec):
    num_words = len(word_vec)
    indices = [[xi, 0, yi] for xi,x in enumerate(word_vec) for yi,y in enumerate(x)]
    chars = list(''.join(word_vec))
    return(tf.SparseTensorValue(indices, chars, [num_words,1,1]))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    t_data_names=tf.constant(df['NAME'])
    reference_names = [el.decode('UTF-8') for el in (t_data_names.eval())]

    sparse_ref_names = sparse_from_word_vec(reference_names)
    sparse_test_name = sparse_from_word_vec([str(input_name.eval().decode('utf-8'))]*5)

    feeddict={test_name: sparse_test_name,
              ref_names: sparse_ref_names, 
              }    

    output_dist = sess.run(output_dist, feed_dict=feeddict)
    output_dist = tf.reduce_min(output_dist, 0)
    print(output_dist.eval())

    tf.saved_model.simple_save(sess,
                               "model-simple",
                               inputs={"input_name": input_name},
                               outputs={"output_dist": output_dist})

And here is my Java method:

public void run(ApplicationArguments args) throws Exception {
  log.info("Loading model...");

  SavedModelBundle savedModelBundle = SavedModelBundle.load("/model", "serve");

  byte[] test_name = "Max Mustermann".toLowerCase().getBytes("UTF-8");


  List<Tensor<?>> output = savedModelBundle.session().runner()
      .feed("input_name", Tensor.<String>create(test_names))
      .fetch("output_dist")
      .run();

  System.out.println("Nearest distance: " + output.get(0).floatValue());

}
McAngus :

I was able to get your example working. I have a couple of comments on your python code before diving in.

You use the variable output_dist for three different kinds of values throughout the code. I'm not a python expert, but I think it's bad practice. You also never actually use the input_name placeholder, except for exporting it as an input. The last one is that tf.saved_model.simple_save is deprecated, and you should use the tf.saved_model.Builder instead.

Now for the solution.

Looking at the libtensorflow jar file using the command jar tvf libtensorflow-x.x.x.jar (thanks to this post), you can see that there are no useful bindings for creating a sparse tensor (maybe make a feature request?). So we have to change the input to a dense tensor, then add operations to the graph to convert it to sparse. In your original code the sparse conversion was on the python side, which means that the loaded graph in Java wouldn't have any ops for it.
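
Just to illustrate what that in-graph conversion does with the "/" padding, here is a tiny standalone sketch (my own illustration, not part of the model below): every element equal to the given token is dropped, so only the real characters end up in the SparseTensor.

import tensorflow as tf

#two names as character rows, padded with "/" so both rows have the same length
dense = tf.constant([['m', 'a', 'x', '/', '/'],
                     ['h', 'a', 'y', 'd', 'n']])

#every element equal to the token "/" is left out of the SparseTensor
sparse = tf.contrib.layers.dense_to_sparse(dense, eos_token='/')

with tf.Session() as sess:
    print(sess.run(sparse))
    #prints a SparseTensorValue whose values are [b'm', b'a', b'x', b'h', b'a', b'y', b'd', b'n'],
    #with indices only for the non-"/" positions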

Here is the new python code:

import tensorflow as tf
import pandas as pd

def model():
    #use dense tensors then convert to sparse for edit_distance
    test_name = tf.placeholder(shape=(None, None), dtype=tf.string, name="test_name")
    ref_names = tf.placeholder(shape=(None, None), dtype=tf.string, name="ref_names")

    #Java does not play well with the empty character, so use "/" instead
    test_name_sparse = tf.contrib.layers.dense_to_sparse(test_name, "/")
    ref_names_sparse = tf.contrib.layers.dense_to_sparse(ref_names, "/")

    output_dist = tf.edit_distance(test_name_sparse, ref_names_sparse, normalize=True)

    #output the index to the closest ref name
    min_idx = tf.argmin(output_dist)

    return test_name, ref_names, min_idx

#Python code to be replicated in Java
def pad_string(s, max_len):
    return s + ["/"] * (max_len - len(s))

d = {'NAME': ['joseph haydn', 
              'max mustermann', 
              'erika musterfrau', 
              'johann sebastian bach', 
              'wolfgang amadeus mozart']}

df = pd.DataFrame(data=d)  
input_name = 'max musterman'

#pad dense tensor input
max_len = max([len(n) for n in df['NAME']])

#test rows are all copies of the same name, so no padding needed there
test_input = [list(input_name)]*len(df['NAME'])
ref_input = list(map(lambda x: pad_string(x, max_len), [list(n) for n in df['NAME']]))


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    test_name, ref_names, min_idx = model()

    #run a test to make sure the model works
    feeddict = {test_name: test_input,
                ref_names: ref_input,
            }
    out = sess.run(min_idx, feed_dict=feeddict)
    print("test output:", out)

    #save the model with the new Builder API
    signature_def_map= {
    "predict": tf.saved_model.signature_def_utils.predict_signature_def(
        inputs= {"test_name": test_name, "ref_names": ref_names},
        outputs= {"min_idx": min_idx})
    }

    builder = tf.saved_model.Builder("model")
    builder.add_meta_graph_and_variables(sess, ["serve"], signature_def_map=signature_def_map)
    builder.save()
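
One more note before the Java part: tf.argmin above isn't given an explicit name, so TensorFlow names the op "ArgMin", and that is the name the Java code fetches. If you want to double check the names in your own graph, a quick debugging sketch (not required by the model) is to print them right before saving:

#list every op name in the graph so you know what to fetch from the Java side
for op in tf.get_default_graph().get_operations():
    print(op.name)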

And here is the java to load and run it. There is probably a lot of room for improvement here (java isn't my main language), but it gives you the idea.

import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.SavedModelBundle;

import java.util.ArrayList;
import java.util.List;
import java.util.Arrays;

public class Test {
    public static byte[][] makeTensor(String s, int padding) throws Exception
    {
        int len = s.length();
        int extra = padding - len;

        byte[][] ret = new byte[len + extra][];
        for (int i = 0; i < len; i++) {
            String cur = "" + s.charAt(i);
            byte[] cur_b = cur.getBytes("UTF-8");
            ret[i] = cur_b;
        }

        for (int i = 0; i < extra; i++) {
            byte[] cur = "/".getBytes("UTF-8");
            ret[len + i] = cur;
        }

        return ret;
    }
    public static byte[][][] makeTensor(List<String> l, int padding) throws Exception
    {
        byte[][][] ret = new byte[l.size()][][];
        for (int i = 0; i < l.size(); i++) {
            ret[i] = makeTensor(l.get(i), padding);
        }

        return ret;
    }
    public static void main(String[] args) throws Exception {
        System.out.println("Loading model...");

        SavedModelBundle savedModelBundle = SavedModelBundle.load("model", "serve");


        List<String> str_test_name = Arrays.asList("Max Mustermann",
            "Max Mustermann",
            "Max Mustermann",
            "Max Mustermann",
            "Max Mustermann");
        List<String> names = Arrays.asList("joseph haydn",
            "max mustermann",
            "erika musterfrau",
            "johann sebastian bach",
            "wolfgang amadeus mozart");

        //get the max length for each array
        int pad1 = str_test_name.get(0).length();
        int pad2 = 0;
        for (String var : names) {
            if(var.length() > pad2)
                pad2 = var.length();
        }


        byte[][][] test_name = makeTensor(str_test_name, pad1);
        byte[][][] ref_names = makeTensor(names, pad2);

        //use a with block so the close method is called
        try(Tensor t_test_name = Tensor.<String>create(test_name))
        {
            try (Tensor t_ref_names = Tensor.<String>create(ref_names))
            {
                List<Tensor<?>> output = savedModelBundle.session().runner()
                    .feed("test_name", t_test_name)
                    .feed("ref_names", t_ref_names)
                    .fetch("ArgMin")
                    .run();

                System.out.println("Nearest distance: " + output.get(0).longValue());
            }
        }
    }
}
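
As a small side note, SavedModelBundle is also AutoCloseable, so the same try-with-resources pattern used for the tensors could be applied to the bundle itself.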
