Skip to content
Prev Previous commit
Use projecting read callback to allow interface projections.
Along the lines fix entity operations proxy handling by reading the underlying map instead of inspecting the proxy interface. Also make sure to map potential raw fields back to the according property.
  • Loading branch information
christophstrobl committed Mar 17, 2023
commit 135634e9edb1651a43867f1bb76bda6f2888280c
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
import org.springframework.data.mapping.MappingException;
import org.springframework.data.mapping.PersistentEntity;
import org.springframework.data.mapping.PersistentPropertyAccessor;
import org.springframework.data.mapping.PersistentPropertyPath;
import org.springframework.data.mapping.PropertyPath;
import org.springframework.data.mapping.context.MappingContext;
import org.springframework.data.mapping.model.ConvertingPropertyAccessor;
import org.springframework.data.mongodb.core.CollectionOptions.TimeSeriesOptions;
Expand All @@ -50,6 +52,7 @@
import org.springframework.data.projection.EntityProjection;
import org.springframework.data.projection.EntityProjectionIntrospector;
import org.springframework.data.projection.ProjectionFactory;
import org.springframework.data.projection.TargetAware;
import org.springframework.data.util.Optionals;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -117,12 +120,16 @@ <T> Entity<T> forEntity(T entity) {

Assert.notNull(entity, "Bean must not be null");

if (entity instanceof TargetAware targetAware) {
return new SimpleMappedEntity((Map<String, Object>) targetAware.getTarget(), this);
}

if (entity instanceof String) {
return new UnmappedEntity(parse(entity.toString()));
return new UnmappedEntity(parse(entity.toString()), this);
}

if (entity instanceof Map) {
return new SimpleMappedEntity((Map<String, Object>) entity);
return new SimpleMappedEntity((Map<String, Object>) entity, this);
}

return MappedEntity.of(entity, context, this);
Expand All @@ -142,11 +149,11 @@ <T> AdaptibleEntity<T> forEntity(T entity, ConversionService conversionService)
Assert.notNull(conversionService, "ConversionService must not be null");

if (entity instanceof String) {
return new UnmappedEntity(parse(entity.toString()));
return new UnmappedEntity(parse(entity.toString()), this);
}

if (entity instanceof Map) {
return new SimpleMappedEntity((Map<String, Object>) entity);
return new SimpleMappedEntity((Map<String, Object>) entity, this);
}

return AdaptibleMappedEntity.of(entity, context, conversionService, this);
Expand Down Expand Up @@ -287,7 +294,8 @@ public <T> TypedOperations<T> forType(@Nullable Class<T> entityClass) {
*/
public <M, D> EntityProjection<M, D> introspectProjection(Class<M> resultType, Class<D> entityType) {

if (!queryMapper.getMappingContext().hasPersistentEntityFor(entityType)) {
MongoPersistentEntity<?> persistentEntity = queryMapper.getMappingContext().getPersistentEntity(entityType);
if (persistentEntity == null && !resultType.isInterface() || ClassUtils.isAssignable(Document.class, resultType)) {
return (EntityProjection) EntityProjection.nonProjecting(resultType);
}
return introspector.introspect(resultType, entityType);
Expand Down Expand Up @@ -369,6 +377,7 @@ private Document getMappedValidator(Validator validator, Class<?> domainType) {
* A representation of information about an entity.
*
* @author Oliver Gierke
* @author Christoph Strobl
* @since 2.1
*/
interface Entity<T> {
Expand Down Expand Up @@ -471,10 +480,10 @@ default boolean isVersionedEntity() {
/**
* @param sortObject
* @return
* @since 3.1
* @since 4.1
* @throws IllegalStateException if a sort key yields {@literal null}.
*/
Map<String, Object> extractKeys(Document sortObject);
Map<String, Object> extractKeys(Document sortObject, Class<?> sourceType);

}

Expand Down Expand Up @@ -523,9 +532,11 @@ interface AdaptibleEntity<T> extends Entity<T> {
private static class UnmappedEntity<T extends Map<String, Object>> implements AdaptibleEntity<T> {

private final T map;
private final EntityOperations entityOperations;

protected UnmappedEntity(T map) {
protected UnmappedEntity(T map, EntityOperations entityOperations) {
this.map = map;
this.entityOperations = entityOperations;
}

@Override
Expand Down Expand Up @@ -596,13 +607,19 @@ public boolean isNew() {
}

@Override
public Map<String, Object> extractKeys(Document sortObject) {
public Map<String, Object> extractKeys(Document sortObject, Class<?> sourceType) {

Map<String, Object> keyset = new LinkedHashMap<>();
keyset.put(ID_FIELD, getId());
MongoPersistentEntity<?> sourceEntity = entityOperations.context.getPersistentEntity(sourceType);
if (sourceEntity != null && sourceEntity.hasIdProperty()) {
keyset.put(sourceEntity.getRequiredIdProperty().getName(), getId());
} else {
keyset.put(ID_FIELD, getId());
}

for (String key : sortObject.keySet()) {
Object value = BsonUtils.resolveValue(map, key);

Object value = resolveValue(key, sourceEntity);

if (value == null) {
throw new IllegalStateException(
Expand All @@ -614,12 +631,24 @@ public Map<String, Object> extractKeys(Document sortObject) {

return keyset;
}

@Nullable
private Object resolveValue(String key, @Nullable MongoPersistentEntity<?> sourceEntity) {

if (sourceEntity == null) {
return BsonUtils.resolveValue(map, key);
}
PropertyPath from = PropertyPath.from(key, sourceEntity.getTypeInformation());
PersistentPropertyPath<MongoPersistentProperty> persistentPropertyPath = entityOperations.context
.getPersistentPropertyPath(from);
return BsonUtils.resolveValue(map, persistentPropertyPath.toDotPath(p -> p.getFieldName()));
}
}

private static class SimpleMappedEntity<T extends Map<String, Object>> extends UnmappedEntity<T> {

protected SimpleMappedEntity(T map) {
super(map);
protected SimpleMappedEntity(T map, EntityOperations entityOperations) {
super(map, entityOperations);
}

@Override
Expand Down Expand Up @@ -758,10 +787,15 @@ public boolean isNew() {
}

@Override
public Map<String, Object> extractKeys(Document sortObject) {
public Map<String, Object> extractKeys(Document sortObject, Class<?> sourceType) {

Map<String, Object> keyset = new LinkedHashMap<>();
keyset.put(entity.getRequiredIdProperty().getName(), getId());
MongoPersistentEntity<?> sourceEntity = entityOperations.context.getPersistentEntity(sourceType);
if (sourceEntity != null && sourceEntity.hasIdProperty()) {
keyset.put(sourceEntity.getRequiredIdProperty().getName(), getId());
} else {
keyset.put(entity.getRequiredIdProperty().getName(), getId());
}

for (String key : sortObject.keySet()) {

Expand Down Expand Up @@ -933,6 +967,14 @@ interface TypedOperations<T> {
* @since 3.3
*/
TimeSeriesOptions mapTimeSeriesOptions(TimeSeriesOptions options);

/**
* @return the name of the id field.
* @since 4.1
*/
default String getIdKeyName() {
return ID_FIELD;
}
}

/**
Expand Down Expand Up @@ -1055,6 +1097,11 @@ private String mappedNameOrDefault(String name) {
MongoPersistentProperty persistentProperty = entity.getPersistentProperty(name);
return persistentProperty != null ? persistentProperty.getFieldName() : name;
}

@Override
public String getIdKeyName() {
return entity.getIdProperty().getName();
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,8 @@ <T> Window<T> doScroll(Query query, Class<?> sourceClass, Class<T> targetClass,
Assert.notNull(sourceClass, "Entity type must not be null");
Assert.notNull(targetClass, "Target type must not be null");

ReadDocumentCallback<T> callback = new ReadDocumentCallback<>(mongoConverter, targetClass, collectionName);
EntityProjection<T, ?> projection = operations.introspectProjection(targetClass, sourceClass);
ProjectingReadCallback<?,T> callback = new ProjectingReadCallback<>(mongoConverter, projection, collectionName);
int limit = query.isLimited() ? query.getLimit() + 1 : Integer.MAX_VALUE;

if (query.hasKeyset()) {
Expand All @@ -882,7 +883,7 @@ <T> Window<T> doScroll(Query query, Class<?> sourceClass, Class<T> targetClass,
keysetPaginationQuery.fields(), sourceClass,
new QueryCursorPreparer(query, keysetPaginationQuery.sort(), limit, 0, sourceClass), callback);

return ScrollUtils.createWindow(query.getSortObject(), query.getLimit(), result, operations);
return ScrollUtils.createWindow(query.getSortObject(), query.getLimit(), result, sourceClass, operations);
}

List<T> result = doFind(collectionName, createDelegate(query), query.getQueryObject(), query.getFieldsObject(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,8 @@ <T> Mono<Window<T>> doScroll(Query query, Class<?> sourceClass, Class<T> targetC
Assert.notNull(sourceClass, "Entity type must not be null");
Assert.notNull(targetClass, "Target type must not be null");

EntityProjection<T, ?> projection = operations.introspectProjection(targetClass, sourceClass);
ProjectingReadCallback<?,T> callback = new ProjectingReadCallback<>(mongoConverter, projection, collectionName);
int limit = query.isLimited() ? query.getLimit() + 1 : Integer.MAX_VALUE;

if (query.hasKeyset()) {
Expand All @@ -857,15 +859,15 @@ <T> Mono<Window<T>> doScroll(Query query, Class<?> sourceClass, Class<T> targetC
operations.getIdPropertyName(sourceClass));

Mono<List<T>> result = doFind(collectionName, ReactiveCollectionPreparerDelegate.of(query),
keysetPaginationQuery.query(), keysetPaginationQuery.fields(), targetClass,
new QueryFindPublisherPreparer(query, keysetPaginationQuery.sort(), limit, 0, sourceClass)).collectList();
keysetPaginationQuery.query(), keysetPaginationQuery.fields(), sourceClass,
new QueryFindPublisherPreparer(query, keysetPaginationQuery.sort(), limit, 0, sourceClass), callback).collectList();

return result.map(it -> ScrollUtils.createWindow(query.getSortObject(), query.getLimit(), it, operations));
return result.map(it -> ScrollUtils.createWindow(query.getSortObject(), query.getLimit(), it, sourceClass, operations));
}

Mono<List<T>> result = doFind(collectionName, ReactiveCollectionPreparerDelegate.of(query), query.getQueryObject(),
query.getFieldsObject(), targetClass,
new QueryFindPublisherPreparer(query, query.getSortObject(), limit, query.getSkip(), sourceClass))
query.getFieldsObject(), sourceClass,
new QueryFindPublisherPreparer(query, query.getSortObject(), limit, query.getSkip(), sourceClass), callback)
.collectList();

return result.map(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,15 @@ private static String getComparator(int sortOrder, Direction direction) {
return sortOrder == 1 ? "$gt" : "$lt";
}

static <T> Window<T> createWindow(Document sortObject, int limit, List<T> result, EntityOperations operations) {
static <T> Window<T> createWindow(Document sortObject, int limit, List<T> result, Class<?> sourceType,
EntityOperations operations) {

IntFunction<KeysetScrollPosition> positionFunction = value -> {

T last = result.get(value);
Entity<T> entity = operations.forEntity(last);

Map<String, Object> keys = entity.extractKeys(sortObject);
Map<String, Object> keys = entity.extractKeys(sortObject, sourceType);
return KeysetScrollPosition.of(keys);
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@
import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver;
import org.springframework.data.mongodb.core.mapping.TimeSeries;
import org.springframework.data.mongodb.test.util.MongoTestMappingContext;
import org.springframework.data.projection.SpelAwareProxyProjectionFactory;

/**
* Unit tests for {@link EntityOperations}.
*
* @author Mark Paluch
* @author Christoph Strobl
*/
class EntityOperationsUnitTests {

Expand Down Expand Up @@ -70,7 +72,8 @@ void shouldExtractKeysFromEntity() {

WithNestedDocument object = new WithNestedDocument("foo");

Map<String, Object> keys = operations.forEntity(object).extractKeys(new Document("id", 1));
Map<String, Object> keys = operations.forEntity(object).extractKeys(new Document("id", 1),
WithNestedDocument.class);

assertThat(keys).containsEntry("id", "foo");
}
Expand All @@ -80,7 +83,7 @@ void shouldExtractKeysFromDocument() {

Document object = new Document("id", "foo");

Map<String, Object> keys = operations.forEntity(object).extractKeys(new Document("id", 1));
Map<String, Object> keys = operations.forEntity(object).extractKeys(new Document("id", 1), Document.class);

assertThat(keys).containsEntry("id", "foo");
}
Expand All @@ -90,7 +93,8 @@ void shouldExtractKeysFromNestedEntity() {

WithNestedDocument object = new WithNestedDocument("foo", new WithNestedDocument("bar"), null);

Map<String, Object> keys = operations.forEntity(object).extractKeys(new Document("nested.id", 1));
Map<String, Object> keys = operations.forEntity(object).extractKeys(new Document("nested.id", 1),
WithNestedDocument.class);

assertThat(keys).containsEntry("nested.id", "bar");
}
Expand All @@ -101,7 +105,8 @@ void shouldExtractKeysFromNestedEntityDocument() {
WithNestedDocument object = new WithNestedDocument("foo", new WithNestedDocument("bar"),
new Document("john", "doe"));

Map<String, Object> keys = operations.forEntity(object).extractKeys(new Document("document.john", 1));
Map<String, Object> keys = operations.forEntity(object).extractKeys(new Document("document.john", 1),
WithNestedDocument.class);

assertThat(keys).containsEntry("document.john", "doe");
}
Expand All @@ -111,11 +116,32 @@ void shouldExtractKeysFromNestedDocument() {

Document object = new Document("document", new Document("john", "doe"));

Map<String, Object> keys = operations.forEntity(object).extractKeys(new Document("document.john", 1));
Map<String, Object> keys = operations.forEntity(object).extractKeys(new Document("document.john", 1),
Document.class);

assertThat(keys).containsEntry("document.john", "doe");
}

@Test // GH-4308
void shouldExtractIdPropertyNameFromRawDocument() {

Document object = new Document("_id", "id-1").append("value", "val");

Map<String, Object> keys = operations.forEntity(object).extractKeys(new Document("value", 1), DomainTypeWithIdProperty.class);

assertThat(keys).containsEntry("id", "id-1");
}

@Test // GH-4308
void shouldExtractValuesFromProxy() {

ProjectionInterface source = new SpelAwareProxyProjectionFactory().createProjection(ProjectionInterface.class, new Document("_id", "id-1").append("value", "val"));

Map<String, Object> keys = operations.forEntity(source).extractKeys(new Document("value", 1), DomainTypeWithIdProperty.class);

assertThat(keys).isEqualTo(new Document("id", "id-1").append("value", "val"));
}

<T> EntityOperations.AdaptibleEntity<T> initAdaptibleEntity(T source) {
return operations.forEntity(source, conversionService);
}
Expand Down Expand Up @@ -150,4 +176,8 @@ public WithNestedDocument(String id) {
this.id = id;
}
}

interface ProjectionInterface {
String getValue();
}
}
Loading