Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


python - Get elements of tuple from tf.data.Dataset

I am building a neural network with two input nodes, each connected to its own embedding layer.

I have created a tf.data.Dataset with a tuple as input for the model.

How can I split the tensors in the tuple to forward the first tensor (scalar) to embedding layer 1 and the second tensor (array) to embedding layer 2 in a custom forward pass?

I provided an example below.

Thanks in advance.

import pandas as pd
import tensorflow as tf

from random import randrange

# Collect 100 rows first, then build the DataFrame in one step:
# one customer id and a list of five item ids per row.
rows = [
    {"cust": randrange(100), "items": [randrange(100) for _ in range(5)]}
    for _ in range(100)
]
df = pd.DataFrame(rows, columns=["cust", "items"])

# The object-dtype "items" column is converted to a nested list so that
# TensorFlow can build tensors of shape (100,) and (100, 5).
dataset = tf.data.Dataset.from_tensor_slices(
    (df["cust"].to_numpy(), df["items"].tolist())
)

dataset_batches = dataset.batch(10)

# custom forward pass
def call(self, inputs):
    x = inputs[0]  # This does not work.
    y = inputs[1]  # This does not work.

    x = self.cust(x)  # input layer 1
    y = self.items(y)  # input layer 2

    x = self.emb_cust(x)  # embedding layer 1
    y = self.emb_items(y)  # embedding layer 2

    z = self.pre_calc([x, y])  # lambda layer

    return z
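
To see what the model would actually receive, it can help to inspect the dataset structure directly. The following is a minimal sketch, assuming stand-in integer data in place of the DataFrame; the names `cust_spec`, `items_spec`, `first_batch`, `cust_batch` and `items_batch` are illustrative and not part of the original question.

```python
import tensorflow as tf

# Hypothetical stand-in data: 100 customer ids and 5 item ids per customer.
cust = tf.range(100, dtype=tf.int32)
items = tf.reshape(tf.range(500, dtype=tf.int32) % 100, (100, 5))

dataset_batches = tf.data.Dataset.from_tensor_slices((cust, items)).batch(10)

# element_spec mirrors the tuple structure of each dataset element.
cust_spec, items_spec = dataset_batches.element_spec

# Each batch is a Python tuple of two tensors, so it can be unpacked or indexed.
first_batch = next(iter(dataset_batches))
cust_batch, items_batch = first_batch
```

Because each batch is an ordinary tuple, both `inputs[0]` / `inputs[1]` indexing and tuple unpacking are available once the batch is handed to the model.
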
Question from: https://stackoverflow.com/questions/65890282/get-elements-of-tuple-from-tf-data-dataset


1 Reply


For anyone with a similar question:

The approach in my question is actually correct: unpack the elements of the tuple from the current batch and pass them to the model as a list for the forward pass.

def run_model(self, epochs, dataset_batches):
    for epoch in range(epochs):
        for step, (cust, items) in enumerate(dataset_batches):
            # execute forward pass
            y_pred = self([cust, items], training=True)
            ...
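
For a self-contained view of the whole pattern, here is a minimal sketch of a subclassed model that receives the list and splits it inside `call`. The class name `TwoInputModel`, the vocabulary size, the embedding dimension, and the dot product standing in for the lambda layer are all assumptions made for illustration; the original post does not show them.

```python
import tensorflow as tf


class TwoInputModel(tf.keras.Model):
    """Hypothetical two-branch model; names and sizes are illustrative."""

    def __init__(self, vocab_size=100, emb_dim=8):
        super().__init__()
        self.emb_cust = tf.keras.layers.Embedding(vocab_size, emb_dim)
        self.emb_items = tf.keras.layers.Embedding(vocab_size, emb_dim)

    def call(self, inputs, training=False):
        cust, items = inputs          # split the two tensors of the batch
        x = self.emb_cust(cust)       # shape: (batch, emb_dim)
        y = self.emb_items(items)     # shape: (batch, n_items, emb_dim)
        # Assumed stand-in for the lambda layer: dot product of the
        # customer vector with each item vector.
        return tf.einsum("bd,bnd->bn", x, y)  # shape: (batch, n_items)


# Random stand-in data with the same shapes as in the question.
cust = tf.random.uniform((100,), maxval=100, dtype=tf.int32)
items = tf.random.uniform((100, 5), maxval=100, dtype=tf.int32)
dataset_batches = tf.data.Dataset.from_tensor_slices((cust, items)).batch(10)

model = TwoInputModel()
for step, (cust_batch, items_batch) in enumerate(dataset_batches):
    # Forward pass: the tuple elements are passed on as a list.
    y_pred = model([cust_batch, items_batch], training=True)
```

Unpacking with `cust, items = inputs` is equivalent to `inputs[0]` and `inputs[1]` here; either form works once the model is called with a list or tuple of two tensors.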
