Skip to content
Next Next commit
DATAMONGO-586 initial commit
  • Loading branch information
ttrelle committed Jan 30, 2013
commit 5a3263091d5cd6b1129cc2ce4e06323857034326
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import java.util.List;
import java.util.Set;

import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
import org.springframework.data.mongodb.core.aggregation.AggregationResults;
import org.springframework.data.mongodb.core.convert.MongoConverter;
import org.springframework.data.mongodb.core.geo.GeoResult;
import org.springframework.data.mongodb.core.geo.GeoResults;
Expand Down Expand Up @@ -301,6 +303,16 @@ public interface MongoOperations {
*/
<T> GroupByResults<T> group(Criteria criteria, String inputCollectionName, GroupBy groupBy, Class<T> entityClass);

/**
* Execute an aggregation operation. The raw results will be mapped to the given entity class.
*
* @param inputCollectionName the collection there the aggregation operation will read from.
* @param pipeline The pipeline holding the aggregation operations.
* @param entityClass The parameterized type of the returned list.
* @return The results of the aggregation operation.
*/
<T> AggregationResults<T> aggregate(String inputCollectionName, AggregationPipeline pipeline, Class<T> entityClass);

/**
* Execute a map-reduce operation. The map-reduce operation will be formed with an output type of INLINE
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
*/
package org.springframework.data.mongodb.core;

import static org.springframework.data.mongodb.core.query.Criteria.*;
import static org.springframework.data.mongodb.core.query.SerializationUtils.*;
import static org.springframework.data.mongodb.core.query.Criteria.where;
import static org.springframework.data.mongodb.core.query.SerializationUtils.serializeToJsonSafely;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -54,6 +54,8 @@
import org.springframework.data.mapping.model.BeanWrapper;
import org.springframework.data.mapping.model.MappingException;
import org.springframework.data.mongodb.MongoDbFactory;
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
import org.springframework.data.mongodb.core.aggregation.AggregationResults;
import org.springframework.data.mongodb.core.convert.MappingMongoConverter;
import org.springframework.data.mongodb.core.convert.MongoConverter;
import org.springframework.data.mongodb.core.convert.MongoWriter;
Expand Down Expand Up @@ -1171,6 +1173,31 @@ public <T> GroupByResults<T> group(Criteria criteria, String inputCollectionName

}

public <T> AggregationResults<T> aggregate(String inputCollectionName, AggregationPipeline pipeline, Class<T> entityClass) {
Assert.notNull(inputCollectionName, "Collection name is missing");
Assert.notNull(pipeline, "Aggregation pipeline is missing");
Assert.notNull(entityClass, "Entity class is missing");

// prepare command
DBObject command = new BasicDBObject("aggregate", inputCollectionName );
command.put( "pipeline", pipeline.getOps() );

// execute command
CommandResult commandResult = executeCommand(command);
handleCommandError(commandResult, command);

// map results
@SuppressWarnings("unchecked")
Iterable<DBObject> resultSet = (Iterable<DBObject>) commandResult.get("result");
List<T> mappedResults = new ArrayList<T>();
DbObjectCallback<T> callback = new ReadDbObjectCallback<T>(mongoConverter, entityClass);
for (DBObject dbObject : resultSet) {
mappedResults.add(callback.doWith(dbObject));
}

return new AggregationResults<T>(mappedResults, commandResult);
}

protected String replaceWithResourceIfNecessary(String function) {

String func = function;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright 2011-2012 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.core.aggregation;

import java.util.ArrayList;
import java.util.List;

import org.springframework.util.Assert;

import com.mongodb.DBObject;
import com.mongodb.util.JSON;
import com.mongodb.util.JSONParseException;

/**
* Holds the operation of an aggregation pipeline.
*
* @author Tobias Trelle
*/
public class AggregationPipeline {

private List<DBObject> ops = new ArrayList<DBObject>();

public AggregationPipeline(String... operations) {
Assert.notNull(operations, "Aggregation pipeline operations are missing");

if (operations.length > 0) {
for (int i = 0; i < operations.length; i++) {
ops.add( parseJson(operations[i]) );
}
}

}

public List<DBObject> getOps() {
return ops;
}

private DBObject parseJson(String json) {
try {
return (DBObject) JSON.parse(json);
} catch (JSONParseException e) {
throw new IllegalArgumentException("Not a valid JSON document: " + json, e);
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright 2011-2012 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.core.aggregation;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.springframework.util.Assert;

import com.mongodb.DBObject;

/**
* Collects the results of executing an aggregation operation.
*
* @author Tobias Trelle
*
* @param <T> The class in which the results are mapped onto.
*/
public class AggregationResults<T> implements Iterable<T> {

private final List<T> mappedResults;
private final DBObject rawResults;

private String serverUsed;

public AggregationResults(List<T> mappedResults, DBObject rawResults) {
Assert.notNull(mappedResults);
Assert.notNull(rawResults);
this.mappedResults = mappedResults;
this.rawResults = rawResults;
parseServerUsed();
}

public List<T> getAggregationResult() {
List<T> result = new ArrayList<T>();
Iterator<T> it = iterator();

while (it.hasNext()) {
result.add(it.next());
}

return result;
}

@Override
public Iterator<T> iterator() {
return mappedResults.iterator();
}

public String getServerUsed() {
return serverUsed;
}

private void parseServerUsed() {
// "serverUsed" : "127.0.0.1:27017"
Object object = rawResults.get("serverUsed");
if (object instanceof String) {
serverUsed = (String) object;
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/*
* Copyright 2011-2012 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.core.aggregation;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.junit.Assert.assertThat;

import java.util.ArrayList;
import java.util.List;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.dao.InvalidDataAccessApiUsageException;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;

import com.mongodb.BasicDBObject;
import com.mongodb.DBCollection;
import com.mongodb.DBObject;

/**
* Tests for {@link MongoTemplate#aggregate(String, AggregationPipeline, Class)}.
*
* @author Tobias Trelle
*/
@RunWith(SpringJUnit4ClassRunner.class)
@ContextConfiguration("classpath:infrastructure.xml")
public class AggregationTests {

private static final String INPUT_COLLECTION = "aggregation_test_collection";

@Autowired
MongoTemplate mongoTemplate;

@Before
public void setUp() {
cleanDb();
}

@After
public void cleanUp() {
cleanDb();
}

@Test(expected = IllegalArgumentException.class)
public void shouldHandleMissingInputCollection() {
mongoTemplate.aggregate(null, new AggregationPipeline((String[]) null), TagCount.class);
}

@Test(expected = IllegalArgumentException.class)
public void shouldHandleMissingAggregationPipeline() {
mongoTemplate.aggregate(INPUT_COLLECTION, null, TagCount.class);
}

@Test(expected = IllegalArgumentException.class)
public void shouldHandleMissingEntityClass() {
mongoTemplate.aggregate(INPUT_COLLECTION, new AggregationPipeline((String[]) null), null);
}

@Test(expected = IllegalArgumentException.class)
public void shouldDetectIllegaAggregationOperation() {
// given
AggregationPipeline pipeline = new AggregationPipeline("{ foo bar");

// when
mongoTemplate.aggregate(INPUT_COLLECTION, pipeline, TagCount.class);

// then: throw expected exception
}

@Test
public void shouldAggregate() {
// given
createDocuments();
AggregationPipeline pipeline = new AggregationPipeline("{$project:{_id:0,tags:1}}", "{$unwind: \"$tags\"}",
"{$group:{_id:\"$tags\", n:{$sum:1}}}", "{$project:{tag: \"$_id\", n:1, _id:0}}", "{$sort:{n:-1}}");

// when
AggregationResults<TagCount> results = mongoTemplate.aggregate(INPUT_COLLECTION, pipeline, TagCount.class);

// then
assertThat(results, notNullValue());
assertThat(results.getServerUsed(), is("/127.0.0.1:27017"));

List<TagCount> tagCount = results.getAggregationResult();
assertThat(tagCount, notNullValue());
assertThat(tagCount.size(), is(3));
assertTagCount("spring", 3, tagCount.get(0));
assertTagCount("mongodb", 2, tagCount.get(1));
assertTagCount("nosql", 1, tagCount.get(2));
}

@Test(expected = InvalidDataAccessApiUsageException.class)
public void shouldDetectIllegalAggregationOperation() {
// given
createDocuments();
AggregationPipeline pipeline = new AggregationPipeline("{$foobar:{_id:0,tags:1}}");

// when
mongoTemplate.aggregate(INPUT_COLLECTION, pipeline, TagCount.class);

// then: throw expected exception
}

@Test
public void shouldAggregateEmptyCollection() {
// given
AggregationPipeline pipeline = new AggregationPipeline("{$project:{_id:0,tags:1}}", "{$unwind: \"$tags\"}",
"{$group:{_id:\"$tags\", n:{$sum:1}}}", "{$project:{tag: \"$_id\", n:1, _id:0}}", "{$sort:{n:-1}}");

// when
AggregationResults<TagCount> results = mongoTemplate.aggregate(INPUT_COLLECTION, pipeline, TagCount.class);

// then
assertThat(results, notNullValue());
assertThat(results.getServerUsed(), is("/127.0.0.1:27017"));

List<TagCount> tagCount = results.getAggregationResult();
assertThat(tagCount, notNullValue());
assertThat(tagCount.size(), is(0));
}

protected void cleanDb() {
mongoTemplate.dropCollection(INPUT_COLLECTION);
}

private void createDocuments() {
DBCollection coll = mongoTemplate.getCollection(INPUT_COLLECTION);
coll.insert(createDocument("Doc1", "spring", "mongodb", "nosql"));
coll.insert(createDocument("Doc2", "spring", "mongodb"));
coll.insert(createDocument("Doc3", "spring"));
}

private DBObject createDocument(String title, String... tags) {
DBObject doc = new BasicDBObject("title", title);
List<String> tagList = new ArrayList<String>();
for (String tag : tags) {
tagList.add(tag);
}
doc.put("tags", tagList);

return doc;
}

private void assertTagCount(String tag, int n, TagCount tagCount) {
assertThat(tagCount.getTag(), is(tag));
assertThat(tagCount.getN(), is(n));
}

}
Loading