// function.proto
syntax = "proto3";

package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "FunctionProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "attr_value.proto";
import "op_def.proto";

// A library is a set of named functions.
message FunctionDefLibrary {
  // The functions contained in this library.
  repeated FunctionDef function = 1;

  // Gradient associations. Each entry pairs a function name with the
  // name of the function that serves as its gradient (see GradientDef).
  repeated GradientDef gradient = 2;
}

// A function can be instantiated when the runtime can bind every attr
// with a value. When a GraphDef has a call to a function, it must
// have a binding for every attr defined in the signature.
//
// TODO(zhifengc):
//   * device spec, etc.
message FunctionDef {
  // The definition of the function's name, arguments, return values,
  // attrs etc.
  //
  // OpDef is not declared in this file; presumably it comes from the
  // imported op_def.proto.
  OpDef signature = 1;

  // The body of the function, expressed as a list of nodes.
  //
  // REQUIRES: function.node.ret[*] are unique.
  repeated Node node = 2;

  // A node is a multi-value assignment:
  //   (ret[0], ret[1], ...) = func(arg[0], arg[1], ...)
  //
  // By convention, "func" is resolved by consulting a user-defined
  // library first. If it is not resolved there, "func" is assumed to be
  // a builtin op.
  message Node {
    // The names of the outputs produced by this node. A node may
    // produce multiple outputs; they are named ret[0], ret[1], ..., etc.
    //
    // REQUIRES: function.node.ret[*] are unique across all nodes.
    // REQUIRES: ret.size == func/op def's number of output args.
    repeated string ret = 1;

    // The op/function name.
    string op = 2;

    // Arguments passed to this func/op.
    //
    // arg[i] must be either one of
    // function.signature.input_args[*].name or one of
    // function.node[*].ret[*].
    //
    // NOTE(review): 'input_args' is the spelling used in this comment
    // for the signature's input arguments; OpDef is not defined in this
    // file -- confirm the exact field name against its definition.
    //
    // REQUIRES: arg.size == func/op def's number of input args.
    repeated string arg = 3;

    // Control dependencies.
    //
    // dep[i] must be one of function.node[*].ret[*] or one of
    // function.signature.input_args[*].name.
    repeated string dep = 4;

    // Attrs.
    //
    // 'attr' maps names defined by 'func's attr defs to attr values.
    // Attr values may contain placeholders, which are substituted
    // recursively by concrete values when this node is instantiated.
    // Each placeholder must name an attr listed in the FunctionDef's
    // signature.
    map<string, AttrValue> attr = 5;
  }
}

// GradientDef defines the gradient function of a function defined in
// a function library.
//
// A gradient function g (specified by gradient_func) for a function f
// (specified by function_name) must satisfy the following:
//
// The function 'f' must be a numerical function which takes N inputs
// and produces M outputs. Its gradient function 'g' is a function
// which takes N + M inputs and produces N outputs.
//
// I.e. if we have
//    (y1, y2, ..., y_M) = f(x1, x2, ..., x_N),
// then, g is
//    (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N,
//                                      dL/dy1, dL/dy2, ..., dL/dy_M),
// where L is a scalar-valued function of (x1, x2, ..., x_N) (e.g., the
// loss function). dL/dx_i is the partial derivative of L with respect
// to x_i. Likewise, dL/dy_j is the partial derivative of L with respect
// to y_j.
message GradientDef {
  // The name of the function 'f' whose gradient is being declared.
  string function_name = 1;

  // The name of the gradient function 'g' for 'f'.
  string gradient_func = 2;
}