Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,16 @@ public RootGraphImplementor<?> getEntityGraph(String graphName) {
return delegate.getEntityGraph( graphName );
}

@Override
public void attachExtension(String extensionName, Object extension) {
delegate.attachExtension( extensionName, extension );
}

@Override
public <T> T retrieveExtension(String extensionName, Class<T> extensionType) {
return delegate.retrieveExtension( extensionName, extensionType );
}

@Override
public <T> QueryImplementor<T> createQuery(CriteriaSelect<T> selectQuery) {
return delegate.createQuery( selectQuery );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -621,4 +621,26 @@ default boolean isStatelessSession() {

@Override
RootGraphImplementor<?> getEntityGraph(String graphName);

/**
* Allows attaching session scoped extensions to the particular session instance they are based on.
*
* @param extensionName The name of the extension serves as a "key" for its retrival from a session instance.
* @param extension The extension to attach to the current session.
*/
@Incubating
void attachExtension(String extensionName, Object extension);

/**
* Returns the extensions attached to the current session.
*
* @param extensionName The name of the extension to retrieve.
* @param extensionType The type of the extension to retrieve.
* @param <T> The type of the extension to retrieve.
* @return The extension instance attached to the current session,
* or {@code null} if there is no extension with the requested name attached to the session.
* @throws ClassCastException if the requested extension cannot be cast to the requested type.
*/
@Incubating
<T> T retrieveExtension(String extensionName, Class<T> extensionType);
}
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,16 @@ public RootGraphImplementor<?> getEntityGraph(String graphName) {
return delegate.getEntityGraph( graphName );
}

@Override
public void attachExtension(String extensionName, Object extension) {
delegate.attachExtension( extensionName, extension );
}

@Override
public <T> T retrieveExtension(String extensionName, Class<T> extensionType) {
return delegate.retrieveExtension( extensionName, extensionType );
}

@Override
public <T> List<EntityGraph<? super T>> getEntityGraphs(Class<T> entityClass) {
return delegate.getEntityGraphs( entityClass );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,10 @@
import java.io.Serial;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.TimeZone;
import java.util.UUID;
Expand Down Expand Up @@ -186,6 +188,8 @@ public abstract class AbstractSharedSessionContract implements SharedSessionCont
private transient ExceptionConverter exceptionConverter;
private transient SessionAssociationMarkers sessionAssociationMarkers;

private transient Map<String, Object> extensions;

public AbstractSharedSessionContract(SessionFactoryImpl factory, SessionCreationOptions options) {
this.factory = factory;

Expand Down Expand Up @@ -1704,6 +1708,23 @@ public SessionAssociationMarkers getSessionAssociationMarkers() {
return sessionAssociationMarkers;
}

@Override
public <T> T retrieveExtension(String extensionName, Class<T> extensionType) {
if ( extensions != null ) {
Object extension = extensions.get( extensionName );
return extension == null ? null : extensionType.cast( extension );
}
return null;
}

@Override
public void attachExtension(String extensionName, Object extension) {
if ( extensions == null ) {
extensions = new HashMap<>();
}
extensions.put( extensionName, extension );
}

@Serial
private void writeObject(ObjectOutputStream oos) throws IOException {
SESSION_LOGGER.serializingSession( getSessionIdentifier() );
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright Red Hat Inc. and Hibernate Authors
*/
package org.hibernate.orm.test.engine.spi;

import jakarta.persistence.Id;
import org.hibernate.testing.orm.junit.DomainModel;
import org.hibernate.testing.orm.junit.SessionFactory;
import org.hibernate.testing.orm.junit.SessionFactoryScope;
import org.junit.jupiter.api.Test;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

@DomainModel(annotatedClasses = {
SessionExtensionTest.UselessEntity.class,
})
@SessionFactory
public class SessionExtensionTest {

@Test
public void smoke(SessionFactoryScope scope) {
final String extensionName = "my-extension-key";
scope.inSession( sessionImplementor -> {
sessionImplementor.attachExtension( extensionName, new Extension( 1 ) );

assertThat( sessionImplementor.retrieveExtension( extensionName, Extension.class ) )
.isNotNull()
.isEqualTo( new Extension( 1 ) );
} );

scope.inStatelessSession( sessionImplementor -> {
sessionImplementor.attachExtension( extensionName, new Extension( 1 ) );

assertThat( sessionImplementor.retrieveExtension( extensionName, Extension.class ) )
.isNotNull()
.isEqualTo( new Extension( 1 ) );
} );
}

@Test
public void cast(SessionFactoryScope scope) {
final String extensionName = "my-extension-key";
scope.inSession( sessionImplementor -> {
sessionImplementor.attachExtension( extensionName, new Extension( 1 ) );

assertThatThrownBy(
() -> sessionImplementor.retrieveExtension( extensionName, SessionExtensionTest.class ) )
.isInstanceOf( ClassCastException.class );
} );

scope.inStatelessSession( sessionImplementor -> {
sessionImplementor.attachExtension( extensionName, new Extension( 1 ) );

assertThatThrownBy(
() -> sessionImplementor.retrieveExtension( extensionName, SessionExtensionTest.class ) )
.isInstanceOf( ClassCastException.class );
} );
}

private record Extension(int number) {
}

static class UselessEntity {
@Id
Long id;
}
}