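```python
class Operation:
    """Represents a graph node that computes an output from its input nodes."""

    def __init__(self, input_nodes=None):
        # Copy the inputs into a fresh list (avoids a shared mutable default argument)
        self.input_nodes = list(input_nodes) if input_nodes is not None else []

        # Initialize list of consumers (nodes that take this node's output as input)
        self.consumers = []

        # Append this operation to the list of consumers of all input nodes
        for input_node in self.input_nodes:
            input_node.consumers.append(self)

        # Append this operation to the list of operations in the currently active default graph
        _default_graph.operations.append(self)

    def compute(self, *input_values):
        """Compute this node's output from its input values; subclasses must override this."""
        raise NotImplementedError
```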
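Each `Operation` subclass must implement the `compute` method: it receives the output values of the node's input nodes, in order, and returns the node's own output.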
```python
# Example Operation node: addition
class add(Operation):
    """Returns x + y element-wise."""

    def __init__(self, x, y):
        super().__init__([x, y])

    def compute(self, x_value, y_value):
        # Keep the input values on the node
        self.inputs = [x_value, y_value]
        return x_value + y_value
```
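Because `Session.run` converts list outputs to NumPy arrays, the `+` in `add.compute` acts as NumPy's element-wise addition (with broadcasting) rather than list concatenation. A short illustration using plain NumPy only:

```python
import numpy as np

# On plain Python lists, + concatenates instead of adding element-wise
as_lists = [1, 2] + [3, 4]

# On NumPy arrays, + adds element-wise
as_arrays = np.array([1, 2]) + np.array([3, 4])

# Broadcasting: a 1-D vector is added to every row of a 2-D matrix
matrix = np.array([[1, 2], [3, 4]])
bias = np.array([10, 20])
broadcast = matrix + bias

print(as_lists)            # [1, 2, 3, 4]
print(as_arrays.tolist())  # [4, 6]
print(broadcast.tolist())  # [[11, 22], [13, 24]]
```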
```python
# Example Operation node: matrix multiplication
class matmul(Operation):
    """Multiplies matrix a by matrix b, producing a * b."""

    def __init__(self, a, b):
        super().__init__([a, b])

    def compute(self, a_value, b_value):
        # Keep the input values on the node
        self.inputs = [a_value, b_value]
        return a_value.dot(b_value)
```
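`matmul.compute` relies on `ndarray.dot`. The plain NumPy sketch below shows what that call produces for a matrix-vector and a matrix-matrix product, and that Python lists have no `.dot` method, which is presumably one reason `Session.run` converts list outputs to arrays first.

```python
import numpy as np

A = np.array([[1, 0], [0, -1]])
x = np.array([1, 2])

# Matrix-vector product: each entry is one row of A dotted with x
y = A.dot(x)
print(y.tolist())  # [1, -2]

# For 2-D operands, .dot is ordinary matrix multiplication, same as the @ operator
B = np.array([[2, 0], [1, 3]])
print((A.dot(B) == A @ B).all())  # True

# Plain lists have no .dot method
print(hasattr([[1, 0], [0, -1]], "dot"))  # False
```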
```python
class placeholder:
    """Represents a placeholder node that has to be provided with a value
    when computing the output of a computational graph.
    """

    def __init__(self):
        self.consumers = []
        # Register this placeholder with the currently active default graph
        _default_graph.placeholders.append(self)
```
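`Session.run` looks each placeholder up with `feed_dict[node]`, so the dictionary keys are the placeholder objects themselves. Since `placeholder` defines no `__eq__` or `__hash__`, Python hashes and compares instances by identity, so every placeholder is a distinct key. A minimal illustration, using a stripped-down `placeholder` whose graph registration is omitted (an assumption made only to keep the block standalone):

```python
class placeholder:
    """Stripped-down stand-in: no graph registration, no __eq__/__hash__."""

    def __init__(self):
        self.consumers = []


x = placeholder()
y = placeholder()

# Each placeholder object is its own dictionary key
feed_dict = {x: [1, 2], y: [3, 4]}
print(feed_dict[x])  # [1, 2]

# A placeholder that was never fed has no entry, so the lookup
# feed_dict[node] in Session.run would raise KeyError for it
unfed = placeholder()
try:
    feed_dict[unfed]
except KeyError:
    print("KeyError: placeholder was not fed")
```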
```python
class Variable:
    """Represents a variable (i.e. an intrinsic, changeable parameter of a computational graph)."""

    def __init__(self, initial_value=None):
        self.value = initial_value
        self.consumers = []
        # Register this variable with the currently active default graph
        _default_graph.variables.append(self)
```
```python
import numpy as np


class Session:
    """Represents a particular execution of a computational graph."""

    def run(self, operation, feed_dict=None):
        """Compute the output of `operation`, taking placeholder values from `feed_dict`."""
        if feed_dict is None:
            feed_dict = {}

        # Post-order traversal: every node comes after all of its inputs
        nodes_postorder = traverse_postorder(operation)

        # Iterate all nodes to determine their value
        for node in nodes_postorder:
            if isinstance(node, placeholder):
                # Set the node value to the placeholder value from feed_dict
                node.output = feed_dict[node]
            elif isinstance(node, Variable):
                # Set the node value to the variable's value attribute
                node.output = node.value
            else:  # Operation
                # Get the input values for this operation from the outputs of its input nodes
                node.inputs = [input_node.output for input_node in node.input_nodes]
                # Compute the output of this operation
                node.output = node.compute(*node.inputs)

            # Convert lists to numpy arrays
            if isinstance(node.output, list):
                node.output = np.array(node.output)

        # Return the requested node value
        return operation.output


def traverse_postorder(operation):
    """Return every node `operation` depends on, each listed after its inputs."""
    nodes_postorder = []

    def recurse(node):
        if isinstance(node, Operation):
            for input_node in node.input_nodes:
                recurse(input_node)
        nodes_postorder.append(node)

    recurse(operation)
    return nodes_postorder
```