本文整理汇总了Java中org.nd4j.linalg.indexing.INDArrayIndex类的典型用法代码示例。如果您正苦于以下问题:Java INDArrayIndex类的具体用法?Java INDArrayIndex怎么用?Java INDArrayIndex使用的例子?那么恭喜您, 这里精选的类代码示例或许可以为您提供帮助。
INDArrayIndex类属于org.nd4j.linalg.indexing包,在下文中一共展示了INDArrayIndex类的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Java代码示例。
示例1: loadFeaturesFromString
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Used post training to convert a String to a features INDArray that can be passed to the network output method.
 * <p>
 * Tokens for which {@code wordVectors.hasWord(token)} is false are dropped before vectorization, so
 * every time step of the returned array corresponds to a token that has a word vector.
 *
 * @param reviewContents Contents of the review to vectorize
 * @param maxLength      Maximum length (if review is longer than this: truncate to maxLength). Use
 *                       Integer.MAX_VALUE to not truncate
 * @return Features array of shape [1, vectorSize, min(maxLength, number of known tokens)]
 * @throws IllegalArgumentException if no token has a word vector, or maxLength is not positive
 */
public INDArray loadFeaturesFromString(String reviewContents, int maxLength) {
    List<String> tokens = tokenizerFactory.create(reviewContents).getTokens();
    List<String> tokensFiltered = new ArrayList<>();
    for (String t : tokens) {
        if (wordVectors.hasWord(t)) {
            tokensFiltered.add(t);
        }
    }
    // Sequence length = number of usable tokens, capped at maxLength. (Math.max here would allocate
    // maxLength time steps, e.g. Integer.MAX_VALUE when truncation is disabled.)
    int outputLength = Math.min(maxLength, tokensFiltered.size());
    if (outputLength <= 0) {
        throw new IllegalArgumentException(
                "No tokens with a known word vector in the review (or maxLength <= 0): cannot build features");
    }
    INDArray features = Nd4j.create(1, vectorSize, outputLength);
    for (int j = 0; j < outputLength; j++) {
        // Use the filtered list: tokens without a word vector must not be looked up, and the
        // time-step index j must line up with the filtered sequence.
        String token = tokensFiltered.get(j);
        INDArray vector = wordVectors.getWordVectorMatrix(token);
        features.put(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(j)}, vector);
    }
    return features;
}
开发者ID:IsaacChanghau,项目名称:NeuralNetworksLite,代码行数:26,代码来源:SentimentExampleIterator.java
示例2: testResolvePointVector
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Resolving a single point index against a 4-element vector must either leave the index array
 * unchanged, or expand it to exactly two point indices (0, 1).
 */
@Test
public void testResolvePointVector() {
    INDArray vector = Nd4j.linspace(1, 4, 4);
    INDArrayIndex[] requested = {NDArrayIndex.point(1)};
    INDArrayIndex[] result = NDArrayIndex.resolve(vector.shape(), requested);
    if (result.length != requested.length) {
        // Expanded form: a leading point(0) followed by the original point(1)
        assertEquals(2, result.length);
        assertTrue(result[0] instanceof PointIndex);
        assertEquals(0, result[0].current());
        assertTrue(result[1] instanceof PointIndex);
        assertEquals(1, result[1].current());
    } else {
        assertArrayEquals(requested, result);
    }
}
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:17,代码来源:NDArrayIndexResolveTests.java
示例3: testIndexPointInterval
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Puts a 1x2 block of ones at slice 1, rows 1..2 (inclusive interval), column 1 of a 3x3x3 zero
 * array and checks the printed result. Two renderings are accepted: one with ',' and one with '.'
 * as decimal separator (presumably to tolerate the default locale — confirm).
 */
@Test
@Ignore
public void testIndexPointInterval() {
    INDArray target = Nd4j.zeros(3, 3, 3);
    INDArrayIndex[] indices = {
            NDArrayIndex.point(1),
            NDArrayIndex.interval(1, 2, true),
            NDArrayIndex.point(1)
    };
    INDArray ones = Nd4j.ones(1, 2);
    target.put(indices, ones);
    String commaDecimal = "[[[0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]]\n"
            + " [[0,00,0,00,0,00]\n"
            + " [0,00,1,00,0,00]\n"
            + " [0,00,1,00,0,00]]\n"
            + " [[0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]]]";
    String dotDecimal = "[[[0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]]\n"
            + " [[0.00,0.00,0.00]\n"
            + " [0.00,1.00,0.00]\n"
            + " [0.00,1.00,0.00]]\n"
            + " [[0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]]]";
    String actual = target.toString();
    boolean matchesEither = actual.equals(dotDecimal) || actual.equals(commaDecimal);
    if (!matchesEither) {
        assertEquals(dotDecimal, actual);
    }
}
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:23,代码来源:ShapeResolutionTestsC.java
示例4: testIndexPointAll
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Puts a 1x3 row of ones at slice 1, all rows, column 1 of a 3x3x3 zero array and checks the printed
 * result, accepting either a ',' or a '.' decimal separator (presumably locale-dependent — confirm).
 */
@Test
@Ignore
public void testIndexPointAll() {
    INDArray target = Nd4j.zeros(3, 3, 3);
    INDArrayIndex[] indices = {
            NDArrayIndex.point(1),
            NDArrayIndex.all(),
            NDArrayIndex.point(1)
    };
    INDArray ones = Nd4j.ones(1, 3);
    target.put(indices, ones);
    String commaDecimal = "[[[0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]]\n"
            + " [[0,00,1,00,0,00]\n"
            + " [0,00,1,00,0,00]\n"
            + " [0,00,1,00,0,00]]\n"
            + " [[0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]]]";
    String dotDecimal = "[[[0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]]\n"
            + " [[0.00,1.00,0.00]\n"
            + " [0.00,1.00,0.00]\n"
            + " [0.00,1.00,0.00]]\n"
            + " [[0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]]]";
    String actual = target.toString();
    boolean matchesEither = actual.equals(commaDecimal) || actual.equals(dotDecimal);
    if (!matchesEither) {
        assertEquals(dotDecimal, actual);
    }
}
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:22,代码来源:ShapeResolutionTestsC.java
示例5: testIndexIntervalAll
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Puts twelve ones (given as a 2x6 array) into slices 0..1, all rows, columns 1..2 of a 3x3x3 zero
 * array and checks the printed result, accepting either a ',' or a '.' decimal separator
 * (presumably locale-dependent — confirm).
 */
@Test
@Ignore
public void testIndexIntervalAll() {
    INDArray target = Nd4j.zeros(3, 3, 3);
    INDArrayIndex[] indices = {
            NDArrayIndex.interval(0, 1, true),
            NDArrayIndex.all(),
            NDArrayIndex.interval(1, 2, true)
    };
    INDArray ones = Nd4j.ones(2, 6);
    target.put(indices, ones);
    String commaDecimal = "[[[0,00,1,00,1,00]\n"
            + " [0,00,1,00,1,00]\n"
            + " [0,00,1,00,1,00]]\n"
            + " [[0,00,1,00,1,00]\n"
            + " [0,00,1,00,1,00]\n"
            + " [0,00,1,00,1,00]]\n"
            + " [[0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]]]";
    String dotDecimal = "[[[0.00,1.00,1.00]\n"
            + " [0.00,1.00,1.00]\n"
            + " [0.00,1.00,1.00]]\n"
            + " [[0.00,1.00,1.00]\n"
            + " [0.00,1.00,1.00]\n"
            + " [0.00,1.00,1.00]]\n"
            + " [[0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]]]";
    String actual = target.toString();
    boolean matchesEither = actual.equals(commaDecimal) || actual.equals(dotDecimal);
    if (!matchesEither) {
        assertEquals(dotDecimal, actual);
    }
}
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:22,代码来源:ShapeResolutionTestsC.java
示例6: testIndexPointIntervalAll
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Puts a 3x2 block of ones at slice 1, all rows, columns 1..2 of a 3x3x3 zero array and checks the
 * printed result, accepting either a ',' or a '.' decimal separator (presumably locale-dependent —
 * confirm).
 */
@Test
@Ignore
public void testIndexPointIntervalAll() {
    INDArray target = Nd4j.zeros(3, 3, 3);
    INDArrayIndex[] indices = {
            NDArrayIndex.point(1),
            NDArrayIndex.all(),
            NDArrayIndex.interval(1, 2, true)
    };
    INDArray ones = Nd4j.ones(3, 2);
    target.put(indices, ones);
    String commaDecimal = "[[[0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]]\n"
            + " [[0,00,1,00,1,00]\n"
            + " [0,00,1,00,1,00]\n"
            + " [0,00,1,00,1,00]]\n"
            + " [[0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]]]";
    String dotDecimal = "[[[0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]]\n"
            + " [[0.00,1.00,1.00]\n"
            + " [0.00,1.00,1.00]\n"
            + " [0.00,1.00,1.00]]\n"
            + " [[0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]]]";
    String actual = target.toString();
    boolean matchesEither = actual.equals(commaDecimal) || actual.equals(dotDecimal);
    if (!matchesEither) {
        assertEquals(dotDecimal, actual);
    }
}
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:22,代码来源:ShapeResolutionTestsC.java
示例7: mergePerOutputMasks2d
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Merges per-array 2d masks into a single mask of shape {@code outShape}.
 * <p>
 * Rows are laid out in the order of {@code arrays}: array i contributes {@code arrays[i].size(0)}
 * consecutive rows. A {@code null} mask means "all present", so those rows keep the initial value 1.
 *
 * @param outShape shape of the merged mask
 * @param arrays   the arrays being merged; only their size along dimension 0 is read
 * @param masks    per-array masks, parallel to {@code arrays} (masks.length must not exceed
 *                 arrays.length); entries may be null
 * @return the merged mask
 */
public static INDArray mergePerOutputMasks2d(int[] outShape, INDArray[] arrays, INDArray[] masks) {
    int[] numExamplesPerArr = new int[arrays.length];
    for (int i = 0; i < numExamplesPerArr.length; i++) {
        numExamplesPerArr[i] = arrays[i].size(0);
    }
    INDArray outMask = Nd4j.ones(outShape); //Initialize to 'all present' (1s)
    int rowsSoFar = 0;
    for (int i = 0; i < masks.length; i++) {
        int thisRows = numExamplesPerArr[i];
        if (masks[i] != null) {
            outMask.put(new INDArrayIndex[] {NDArrayIndex.interval(rowsSoFar, rowsSoFar + thisRows),
                    NDArrayIndex.all()}, masks[i]);
        }
        // Always advance, even for a null mask: that array's examples still occupy thisRows rows of
        // the output. Skipping this (via an early `continue`) shifted every later mask upwards.
        rowsSoFar += thisRows;
    }
    return outMask;
}
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:22,代码来源:DataSetUtil.java
示例8: toFlattened
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Builds a single 1 x L array, where L is the sum of the lengths of all arrays in {@code matrices}.
 * Each array is written, in the collection's iteration order, into the next consecutive interval.
 *
 * @param matrices the arrays to flatten into one row
 * @return the flattened array
 */
@Override
public INDArray toFlattened(Collection<INDArray> matrices) {
    int totalLength = 0;
    for (INDArray m : matrices) {
        totalLength += m.length();
    }
    INDArray flattened = Nd4j.create(1, totalLength);
    int offset = 0;
    for (INDArray part : matrices) {
        INDArrayIndex[] destination = {NDArrayIndex.interval(offset, offset + part.length())};
        flattened.put(destination, part);
        offset += part.length();
    }
    return flattened;
}
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:23,代码来源:BaseNDArrayFactory.java
示例9: backpropGradient
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Back-propagates through the last-time-step extraction. {@code epsilon} is scattered into a new
 * 'f'-order array of shape {@code origOutputShape} (presumably [minibatch, size, timeSteps] — confirm),
 * which is then passed to the underlying layer.
 * <ul>
 *   <li>No mask ({@code lastTimeStepIdxs == null}): all of epsilon goes to index
 *       {@code origOutputShape[2] - 1} along dimension 2.</li>
 *   <li>Masked: row i of epsilon goes to example i at time step {@code lastTimeStepIdxs[i]}.</li>
 * </ul>
 */
@Override
public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
    INDArray newEps = Nd4j.create(origOutputShape, 'f');
    if (lastTimeStepIdxs != null) {
        INDArrayIndex[] target = new INDArrayIndex[]{null, all(), null};
        //TODO probably possible to optimize this with reshape + scatter ops...
        for (int example = 0; example < lastTimeStepIdxs.length; example++) {
            target[0] = point(example);
            target[2] = point(lastTimeStepIdxs[example]);
            newEps.put(target, epsilon.getRow(example));
        }
    } else {
        newEps.put(new INDArrayIndex[]{all(), all(), point(origOutputShape[2] - 1)}, epsilon);
    }
    return underlying.backpropGradient(newEps);
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:18,代码来源:LastTimeStepLayer.java
示例10: preOutput
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Computes the pre-activation output of the element-wise multiplication layer. Each row of the input
 * is multiplied element-wise by the weight parameter, and the bias row vector is then added.
 * Dropout is applied to the input first (if enabled for training), and the mask, if any, is applied
 * to the result.
 *
 * @param training whether this is a training pass (forwarded to dropout handling)
 * @return array with the same number of rows and columns as the input
 * @throws DL4JInvalidInputException if the input column count differs from the weight column count
 */
public INDArray preOutput(boolean training) {
    INDArray bias = getParam(DefaultParamInitializer.BIAS_KEY);
    INDArray weights = getParam(DefaultParamInitializer.WEIGHT_KEY);
    boolean sizeMatches = input.columns() == weights.columns();
    if (!sizeMatches) {
        throw new DL4JInvalidInputException(
                "Input size (" + input.columns() + " columns; shape = " + Arrays.toString(input.shape())
                        + ") is invalid: does not match layer input size (layer # inputs = "
                        + weights.shapeInfoToString() + ") " + layerId());
    }
    applyDropOutIfNecessary(training);
    INDArray output = Nd4j.zeros(input.rows(), input.columns());
    for (int r = 0; r < input.rows(); r++) {
        INDArrayIndex[] rowSlot = {NDArrayIndex.point(r), NDArrayIndex.all()};
        output.put(rowSlot, input.getRow(r).mul(weights).addRowVector(bias));
    }
    if (maskArray != null) {
        applyMask(output);
    }
    return output;
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:26,代码来源:ElementWiseMultiplicationLayer.java
示例11: doBackward
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Back-propagates through the last-time-step vertex. A new array of shape {@code fwdPassShape} is
 * created and {@code epsilon} is written into it:
 * <ul>
 *   <li>{@code fwdPassTimeSteps == null}: all of epsilon goes to index {@code fwdPassShape[2] - 1}
 *       along dimension 2.</li>
 *   <li>Otherwise: row i of epsilon goes to example i at time step {@code fwdPassTimeSteps[i]}.</li>
 * </ul>
 *
 * @param tbptt unused here
 * @return a pair with no gradient and the single epsilon array
 */
@Override
public Pair<Gradient, INDArray[]> doBackward(boolean tbptt) {
    INDArray epsilonsOut = Nd4j.create(fwdPassShape);
    if (fwdPassTimeSteps != null) {
        for (int example = 0; example < fwdPassTimeSteps.length; example++) {
            INDArrayIndex[] slot = {NDArrayIndex.point(example), NDArrayIndex.all(),
                    NDArrayIndex.point(fwdPassTimeSteps[example])};
            epsilonsOut.put(slot, epsilon.getRow(example));
        }
    } else {
        INDArrayIndex[] slot = {NDArrayIndex.all(), NDArrayIndex.all(),
                NDArrayIndex.point(fwdPassShape[2] - 1)};
        epsilonsOut.put(slot, epsilon);
    }
    return new Pair<>(null, new INDArray[] {epsilonsOut});
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:20,代码来源:LastTimeStepVertex.java
示例12: doBackward
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Back-propagates through the unstack vertex. A zero array of shape {@code forwardShape} is created,
 * and {@code epsilon} is written into rows [from*step, (from+1)*step) of dimension 0. All other
 * dimensions are taken in full. Only ranks 2, 3 and 4 are supported.
 *
 * @param tbptt unused here
 * @return a pair with no gradient and the single epsilon array
 * @throws IllegalStateException if the epsilon has not been set
 * @throws RuntimeException      if {@code forwardShape} has an unsupported rank
 */
@Override
public Pair<Gradient, INDArray[]> doBackward(boolean tbptt) {
    if (!canDoBackward())
        throw new IllegalStateException("Cannot do backward pass: error not set");
    INDArray out = Nd4j.zeros(forwardShape);
    int rank = forwardShape.length;
    if (rank < 2 || rank > 4) {
        throw new RuntimeException("Invalid activation rank"); //Should never happen
    }
    INDArrayIndex[] indices = new INDArrayIndex[rank];
    indices[0] = NDArrayIndex.interval(from * step, (from + 1) * step);
    for (int d = 1; d < rank; d++) {
        indices[d] = NDArrayIndex.all();
    }
    out.put(indices, epsilon);
    return new Pair<>(null, new INDArray[] {out});
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:27,代码来源:UnstackVertex.java
示例13: doBackward
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Back-propagates through the subset vertex. A zero array of shape {@code forwardShape} is created,
 * and {@code epsilon} is written into indices from..to (inclusive) of dimension 1. All other
 * dimensions are taken in full. Only ranks 2, 3 and 4 are supported.
 *
 * @param tbptt unused here
 * @return a pair with no gradient and the single epsilon array
 * @throws IllegalStateException if the epsilon has not been set
 * @throws RuntimeException      if {@code forwardShape} has an unsupported rank
 */
@Override
public Pair<Gradient, INDArray[]> doBackward(boolean tbptt) {
    if (!canDoBackward())
        throw new IllegalStateException("Cannot do backward pass: error not set");
    INDArray out = Nd4j.zeros(forwardShape);
    int rank = forwardShape.length;
    if (rank < 2 || rank > 4) {
        throw new RuntimeException("Invalid activation rank"); //Should never happen
    }
    INDArrayIndex[] indices = new INDArrayIndex[rank];
    for (int d = 0; d < rank; d++) {
        indices[d] = (d == 1) ? NDArrayIndex.interval(from, to, true) : NDArrayIndex.all();
    }
    out.put(indices, epsilon);
    return new Pair<>(null, new INDArray[] {out});
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:24,代码来源:SubsetVertex.java
示例14: putExample
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Writes {@code singleExample} into position {@code exampleIdx} along dimension 0 of {@code arr}.
 * All remaining dimensions are taken in full. Only ranks 2, 3 and 4 are supported.
 *
 * @param arr           destination array
 * @param singleExample data for one example
 * @param exampleIdx    index along dimension 0 to write to
 * @throws RuntimeException if {@code arr} has an unsupported rank
 */
private void putExample(INDArray arr, INDArray singleExample, int exampleIdx) {
    int rank = arr.rank();
    if (rank < 2 || rank > 4) {
        throw new RuntimeException("Unexpected rank: " + rank);
    }
    INDArrayIndex[] indices = new INDArrayIndex[rank];
    indices[0] = NDArrayIndex.point(exampleIdx);
    for (int d = 1; d < rank; d++) {
        indices[d] = NDArrayIndex.all();
    }
    arr.put(indices, singleExample);
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:18,代码来源:RecordReaderMultiDataSetIterator.java
示例15: loadSingleSentence
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Generally used post training time to load a single sentence for predictions.
 * <p>
 * Let length = min(maxSentenceLength, number of tokens). The returned array has shape
 * [1, 1, length, wordVectorSize] when {@code sentencesAlongHeight} is true, and
 * [1, 1, wordVectorSize, length] otherwise. Token i's vector is written at position i of the length
 * dimension, spanning the whole word-vector dimension.
 *
 * @param sentence the raw sentence to vectorize
 * @return the features array
 * @throws IllegalArgumentException if the sentence yields no tokens
 */
public INDArray loadSingleSentence(String sentence) {
    List<String> tokens = tokenizeSentence(sentence);
    if (tokens.isEmpty()) {
        // Without this check a zero-length dimension would be requested below, which fails (if at
        // all) with an obscure error far from the real cause.
        throw new IllegalArgumentException("No tokens available for input sentence: \"" + sentence + "\"");
    }
    int length = Math.min(maxSentenceLength, tokens.size());
    int[] featuresShape = sentencesAlongHeight
            ? new int[] {1, 1, length, wordVectorSize}
            : new int[] {1, 1, wordVectorSize, length};
    INDArray features = Nd4j.create(featuresShape);
    for (int i = 0; i < length; i++) {
        INDArray vector = getVector(tokens.get(i));
        INDArrayIndex[] indices = new INDArrayIndex[4];
        indices[0] = NDArrayIndex.point(0);
        indices[1] = NDArrayIndex.point(0);
        if (sentencesAlongHeight) {
            indices[2] = NDArrayIndex.point(i);
            indices[3] = NDArrayIndex.all();
        } else {
            indices[2] = NDArrayIndex.all();
            indices[3] = NDArrayIndex.point(i);
        }
        features.put(indices, vector);
    }
    return features;
}
开发者ID:Waikato,项目名称:wekaDeeplearning4j,代码行数:35,代码来源:CnnSentenceDataSetIterator.java
示例16: updateFilteredBiasCovariates
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Applies the Fourier filter to a bias covariates matrix one column (latent index) at a time. The
 * filtered result is then partitioned and pushed to the compute block(s) according to the configured
 * bias-covariates communication policy.
 *
 * @param biasCovariates any T x D bias covariates matrix
 * @throws IllegalStateException if the configured communication policy is not handled here
 */
@UpdatesRDD
private void updateFilteredBiasCovariates(@Nonnull final INDArray biasCovariates) {
    final INDArray filteredBiasCovariates = Nd4j.create(biasCovariates.shape());
    /* instantiate the Fourier filter */
    final FourierLinearOperatorNDArray regularizerFourierLinearOperator = createRegularizerFourierLinearOperator();
    /* FFT by resolving W_tl on l: filter column li of the input into column li of the output */
    for (int li = 0; li < numLatents; li++) {
        final INDArrayIndex[] slice = {NDArrayIndex.all(), NDArrayIndex.point(li)};
        filteredBiasCovariates.get(slice).assign(
                regularizerFourierLinearOperator.operate(biasCovariates.get(slice)));
    }
    /* send the new W to workers */
    switch (config.getBiasCovariatesComputeNodeCommunicationPolicy()) {
        case BROADCAST_HASH_JOIN:
            pushToWorkers(mapINDArrayToBlocks(filteredBiasCovariates),
                    (W, cb) -> cb.cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock
                            .CoverageModelICGCacheNode.F_W_tl, W.get(cb.getTargetSpaceBlock())));
            break;
        case RDD_JOIN:
            joinWithWorkersAndMap(chopINDArrayToBlocks(filteredBiasCovariates),
                    p -> p._1.cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock
                            .CoverageModelICGCacheNode.F_W_tl, p._2));
            break;
        default:
            // Fail loudly instead of silently leaving the workers with stale filtered covariates
            throw new IllegalStateException("Unsupported bias covariates compute node communication policy: "
                    + config.getBiasCovariatesComputeNodeCommunicationPolicy());
    }
}
开发者ID:broadinstitute,项目名称:gatk-protected,代码行数:36,代码来源:CoverageModelEMWorkspace.java
示例17: sgeqrf
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
* Single-precision QR factorization of the M x N matrix A using LAPACK sgeqrf followed by sorgqr.
* <p>
* A is modified in place through its native buffer: after sorgqr it holds the explicit Q factor.
* If R is non-null it receives the upper-triangular R factor.
*
* @param M    number of rows of A
* @param N    number of columns of A
* @param A    input matrix; overwritten with Q
* @param R    output for the R factor, or null to skip it. NOTE(review): it is assigned the first
*             A.columns() rows of A, so presumably must be N x N (and M >= N) — confirm with callers
* @param INFO not used by this implementation; failures are reported by throwing BlasException
* @throws BlasException if either LAPACK call returns a non-zero status
*/
@Override
public void sgeqrf(int M, int N, INDArray A, INDArray R, INDArray INFO) {
// Scalar factors of the elementary reflectors: written by sgeqrf, consumed by sorgqr
INDArray tau = Nd4j.create( N ) ;
int status = LAPACKE_sgeqrf(getColumnOrder(A), M, N,
(FloatPointer)A.data().addressPointer(), getLda(A),
(FloatPointer)tau.data().addressPointer()
);
if( status != 0 ) {
throw new BlasException( "Failed to execute sgeqrf", status ) ;
}
// After sgeqrf, R sits on and above the diagonal of A (the reflectors for Q are stored below it).
// Copy the first A.columns() rows of A into R, then zero R's strictly-lower triangle, row by row.
// This must happen before sorgqr overwrites A with Q.
if( R != null ) {
R.assign( A.get( NDArrayIndex.interval( 0, A.columns() ), NDArrayIndex.all() ) ) ;
INDArrayIndex ix[] = new INDArrayIndex[ 2 ] ;
for( int i=1 ; i<Math.min( A.rows(), A.columns() ) ; i++ ) {
// Row i, columns [0, i): the entries below the diagonal
ix[0] = NDArrayIndex.point( i ) ;
ix[1] = NDArrayIndex.interval( 0, i ) ;
R.put(ix, 0) ;
}
}
// Form Q explicitly (in place in A) from the reflectors and tau
status = LAPACKE_sorgqr( getColumnOrder(A), M, N, N,
(FloatPointer)A.data().addressPointer(), getLda(A),
(FloatPointer)tau.data().addressPointer()
);
if( status != 0 ) {
throw new BlasException( "Failed to execute sorgqr", status ) ;
}
}
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:33,代码来源:CpuLapack.java
示例18: dgeqrf
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
* Double-precision QR factorization of the M x N matrix A using LAPACK dgeqrf followed by dorgqr.
* <p>
* A is modified in place through its native buffer: after dorgqr it holds the explicit Q factor.
* If R is non-null it receives the upper-triangular R factor.
*
* @param M    number of rows of A
* @param N    number of columns of A
* @param A    input matrix; overwritten with Q
* @param R    output for the R factor, or null to skip it. NOTE(review): it is assigned the first
*             A.columns() rows of A, so presumably must be N x N (and M >= N) — confirm with callers
* @param INFO not used by this implementation; failures are reported by throwing BlasException
* @throws BlasException if either LAPACK call returns a non-zero status
*/
@Override
public void dgeqrf(int M, int N, INDArray A, INDArray R, INDArray INFO) {
// Scalar factors of the elementary reflectors: written by dgeqrf, consumed by dorgqr
INDArray tau = Nd4j.create( N ) ;
int status = LAPACKE_dgeqrf(getColumnOrder(A), M, N,
(DoublePointer)A.data().addressPointer(), getLda(A),
(DoublePointer)tau.data().addressPointer()
);
if( status != 0 ) {
throw new BlasException( "Failed to execute dgeqrf", status ) ;
}
// After dgeqrf, R sits on and above the diagonal of A (the reflectors for Q are stored below it).
// Copy the first A.columns() rows of A into R, then zero R's strictly-lower triangle, row by row.
// This must happen before dorgqr overwrites A with Q.
if( R != null ) {
R.assign( A.get( NDArrayIndex.interval( 0, A.columns() ), NDArrayIndex.all() ) ) ;
INDArrayIndex ix[] = new INDArrayIndex[ 2 ] ;
for( int i=1 ; i<Math.min( A.rows(), A.columns() ) ; i++ ) {
// Row i, columns [0, i): the entries below the diagonal
ix[0] = NDArrayIndex.point( i ) ;
ix[1] = NDArrayIndex.interval( 0, i ) ;
R.put(ix, 0) ;
}
}
// Form Q explicitly (in place in A) from the reflectors and tau
status = LAPACKE_dorgqr( getColumnOrder(A), M, N, N,
(DoublePointer)A.data().addressPointer(), getLda(A),
(DoublePointer)tau.data().addressPointer()
);
if( status != 0 ) {
throw new BlasException( "Failed to execute dorgqr", status ) ;
}
}
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:33,代码来源:CpuLapack.java
示例19: testResolvePoint
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Checks index resolution against a 2x2 matrix:
 * <ul>
 *   <li>a lone point(1) becomes {point(1), all()};</li>
 *   <li>a lone all() becomes {all(), all()};</li>
 *   <li>an already fully specified {all(), point(1)} resolves to an equal array.</li>
 * </ul>
 */
@Test
public void testResolvePoint() {
    INDArray matrix = Nd4j.linspace(1, 4, 4).reshape(2, 2);

    INDArrayIndex[] resolvedPoint = NDArrayIndex.resolve(matrix.shape(), NDArrayIndex.point(1));
    assertArrayEquals(new INDArrayIndex[] {NDArrayIndex.point(1), NDArrayIndex.all()}, resolvedPoint);

    INDArrayIndex[] expectedAll = {NDArrayIndex.all(), NDArrayIndex.all()};
    INDArrayIndex[] resolvedAll = NDArrayIndex.resolve(matrix.shape(), NDArrayIndex.all());
    assertArrayEquals(expectedAll, resolvedAll);

    INDArrayIndex[] fullySpecified = {NDArrayIndex.all(), NDArrayIndex.point(1)};
    assertArrayEquals(fullySpecified, NDArrayIndex.resolve(matrix.shape(), fullySpecified));
}
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:14,代码来源:NDArrayIndexResolveTests.java
示例20: testNegativeBounds
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * An interval with a negative end (-2) over the 5 columns of a 2x5 matrix is expected to select
 * columns 0..2, as the expected values below show.
 */
@Test
public void testNegativeBounds() {
    INDArray arr = Nd4j.linspace(1, 10, 10).reshape(2, 5);
    INDArrayIndex columns = NDArrayIndex.interval(0, 1, -2, arr.size(1));
    INDArray selected = arr.get(NDArrayIndex.all(), columns);
    double[][] expectedValues = {
            {1, 2, 3},
            {6, 7, 8}
    };
    assertEquals(Nd4j.create(expectedValues), selected);
}
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:12,代码来源:IndexingTestsC.java
注:本文中的org.nd4j.linalg.indexing.INDArrayIndex类示例整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论