syntax = "proto3";

package tensorflow;

import "attr_value.proto";
import "op_def.proto";

option cc_enable_arenas = true;
option java_outer_classname = "FunctionProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

// A library is a set of named functions.
message FunctionDefLibrary {
  // The functions contained in this library.
  repeated FunctionDef function = 1;

  // The gradient definitions contained in this library (see GradientDef).
  repeated GradientDef gradient = 2;
}

// A function can be instantiated when the runtime can bind every attr
// with a value. When a GraphDef has a call to a function, it must
// have binding for every attr defined in the signature.
//
// TODO(zhifengc):
//   * device spec, etc.
message FunctionDef {
  // The definition of the function's name, arguments, return values,
  // attrs etc.
  OpDef signature = 1;

  // The body of the function.
  //
  // REQUIRES: function.node.ret[*] are unique.
  repeated Node node = 2;

  // A node is a multi-value assignment:
  //   (ret[0], ret[1], ...) = func(arg[0], arg[1], ...)
  //
  // By convention, "func" is resolved by consulting with a user-defined
  // library first. If not resolved, "func" is assumed to be a builtin op.
  message Node {
    // This node produces multiple outputs. They are named ret[0],
    // ret[1], ..., etc.
    //
    // REQUIRES: function.node.ret[*] are unique across all nodes.
    // REQUIRES: ret.size == func/op def's number of output args.
    repeated string ret = 1;

    // The op/function name.
    string op = 2;

    // Arguments passed to this func/op.
    //
    // arg[i] must be either one of
    // function.signature.input_args[*].name or one of
    // function.node[*].ret[*].
    //
    // REQUIRES: arg.size == func/op def's number of input args.
    repeated string arg = 3;

    // Control dependencies.
    //
    // dep[i] must be one of function.node[*].ret[*] or one of
    // function.signature.input_args[*].name.
    repeated string dep = 4;

    // Attrs.
    //
    // 'attr' maps names defined by 'func's attr defs to attr values.
    // attr values may have placeholders which are substituted
    // recursively by concrete values when this node is instantiated.
    // These placeholders must name an attr listed in the FunctionDef's
    // signature.
    map<string, AttrValue> attr = 5;
  }
}

// GradientDef defines the gradient function of a function defined in
// a function library.
//
// A gradient function g (specified by gradient_func) for a function f
// (specified by function_name) must follow the following:
//
// The function 'f' must be a numerical function which takes N inputs
// and produces M outputs. Its gradient function 'g' is a function
// which takes N + M inputs and produces N outputs.
//
// I.e. if we have
//    (y1, y2, ..., y_M) = f(x1, x2, ..., x_N),
// then, g is
//    (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N,
//                                      dL/dy1, dL/dy2, ..., dL/dy_M),
// where L is a scalar-valued function of (x1, x2, ..., x_N) (e.g., the
// loss function). dL/dx_i is the partial derivative of L with respect
// to x_i.
message GradientDef {
  // The function name.
  string function_name = 1;

  // The gradient function's name.
  string gradient_func = 2;
}