import numpy as np
import gym
from gym.spaces import Box
import ray
from ray import tune
from ray.rllib.utils import try_import_tf
import ray.rllib.agents.ddpg as ddpg
from ray.tune.logger import pretty_print
tf = try_import_tf()
# gym environment adapter
class SimpleSupplyChain(gym.Env):
def __init__(self, config):
self.reset()
self.action_space = Box(low=0.0, high=1000.0, shape=(self.supply_chain.retail_store_num + 1, ), dtype=np.float32)
self.observation_space = Box(-1000000.0, 10000000, shape=(len(self.supply_chain.initial_state().to_array()), ), dtype=np.float32)
def reset(self):
self.supply_chain = SupplyChainEnvironment()
self.state = self.supply_chain.initial_state()
return self.state.to_array()
def step(self, action):
action_obj = Action(self.supply_chain.retail_store_num)
action_obj.production_level = action[0]
action_obj.shippings_to_retail_stores = action[1:]
self.state, reward, done = self.supply_chain.step(self.state, action_obj)
return self.state.to_array(), reward, done, {}
ray.shutdown()
ray.init()
def train_ddpg():
config = ddpg.DEFAULT_CONFIG.copy()
config["log_level"] = "WARN"
config["actor_hiddens"] = [512, 512]
config["critic_hiddens"] = [512, 512]
config["gamma"] = 0.95
config["timesteps_per_iteration"] = 1000
config["target_network_update_freq"] = 5
config["buffer_size"] = 10000
# config['actor_lr']=1e-6
# config['critic_lr']=1e-6
print(config)
import json
try:
import cPickle as pickle
except ImportError: # Python 3.x
import pickle
with open('config.p', 'wb') as fp:
pickle.dump(config, fp, protocol=pickle.HIGHEST_PROTOCOL)
trainer = ddpg.DDPGTrainer(config=config, env=SimpleSupplyChain)
for i in range(5):
result = trainer.train()
print(pretty_print(result))
checkpoint = trainer.save('./')
print("Checkpoint saved at", checkpoint)
train_ddpg()
When I run the above code I am getting this error.
Cannot cast array data from dtype('O') to dtype('float32') according to the rule 'same_kind'
I am using DDPG to find the solution but it seems that there is something wrong with the data type. The data comes from pandas and I changed the datatype to float64. I checked and tried everything i could but the error still appeared. Could someone help me please?
question from:
https://stackoverflow.com/questions/65932091/python-error-cannot-cast-array-data-from-dtypeo-to-dtypefloat32-accordi 与恶龙缠斗过久,自身亦成为恶龙;凝视深渊过久,深渊将回以凝视…