@@ -639,11 +639,16 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
639639 using PyConcreteType::PyConcreteType;
640640
641641 static void bindDerived (ClassTy &c) {
642- c.def_static (" get" , &PyVectorType::get , nb::arg (" shape" ),
642+ c.def_static (" get" , &PyVectorType::getChecked , nb::arg (" shape" ),
643643 nb::arg (" element_type" ), nb::kw_only (),
644644 nb::arg (" scalable" ) = nb::none (),
645645 nb::arg (" scalable_dims" ) = nb::none (),
646646 nb::arg (" loc" ) = nb::none (), " Create a vector type" )
647+ .def_static (" get_unchecked" , &PyVectorType::get, nb::arg (" shape" ),
648+ nb::arg (" element_type" ), nb::kw_only (),
649+ nb::arg (" scalable" ) = nb::none (),
650+ nb::arg (" scalable_dims" ) = nb::none (),
651+ nb::arg (" context" ) = nb::none (), " Create a vector type" )
647652 .def_prop_ro (
648653 " scalable" ,
649654 [](MlirType self) { return mlirVectorTypeIsScalable (self); })
@@ -658,10 +663,11 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
658663 }
659664
660665private:
661- static PyVectorType get (std::vector<int64_t > shape, PyType &elementType,
662- std::optional<nb::list> scalable,
663- std::optional<std::vector<int64_t >> scalableDims,
664- DefaultingPyLocation loc) {
666+ static PyVectorType
667+ getChecked (std::vector<int64_t > shape, PyType &elementType,
668+ std::optional<nb::list> scalable,
669+ std::optional<std::vector<int64_t >> scalableDims,
670+ DefaultingPyLocation loc) {
665671 if (scalable && scalableDims) {
666672 throw nb::value_error (" 'scalable' and 'scalable_dims' kwargs "
667673 " are mutually exclusive." );
@@ -696,6 +702,42 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
696702 throw MLIRError (" Invalid type" , errors.take ());
697703 return PyVectorType (elementType.getContext (), type);
698704 }
705+
706+ static PyVectorType get (std::vector<int64_t > shape, PyType &elementType,
707+ std::optional<nb::list> scalable,
708+ std::optional<std::vector<int64_t >> scalableDims,
709+ DefaultingPyMlirContext context) {
710+ if (scalable && scalableDims) {
711+ throw nb::value_error (" 'scalable' and 'scalable_dims' kwargs "
712+ " are mutually exclusive." );
713+ }
714+
715+ PyMlirContext::ErrorCapture errors (context->getRef ());
716+ MlirType type;
717+ if (scalable) {
718+ if (scalable->size () != shape.size ())
719+ throw nb::value_error (" Expected len(scalable) == len(shape)." );
720+
721+ SmallVector<bool > scalableDimFlags = llvm::to_vector (llvm::map_range (
722+ *scalable, [](const nb::handle &h) { return nb::cast<bool >(h); }));
723+ type = mlirVectorTypeGetScalable (shape.size (), shape.data (),
724+ scalableDimFlags.data (), elementType);
725+ } else if (scalableDims) {
726+ SmallVector<bool > scalableDimFlags (shape.size (), false );
727+ for (int64_t dim : *scalableDims) {
728+ if (static_cast <size_t >(dim) >= scalableDimFlags.size () || dim < 0 )
729+ throw nb::value_error (" Scalable dimension index out of bounds." );
730+ scalableDimFlags[dim] = true ;
731+ }
732+ type = mlirVectorTypeGetScalable (shape.size (), shape.data (),
733+ scalableDimFlags.data (), elementType);
734+ } else {
735+ type = mlirVectorTypeGet (shape.size (), shape.data (), elementType);
736+ }
737+ if (mlirTypeIsNull (type))
738+ throw MLIRError (" Invalid type" , errors.take ());
739+ return PyVectorType (elementType.getContext (), type);
740+ }
699741};
700742
701743// / Ranked Tensor Type subclass - RankedTensorType.
@@ -724,6 +766,22 @@ class PyRankedTensorType
724766 nb::arg (" shape" ), nb::arg (" element_type" ),
725767 nb::arg (" encoding" ) = nb::none (), nb::arg (" loc" ) = nb::none (),
726768 " Create a ranked tensor type" );
769+ c.def_static (
770+ " get_unchecked" ,
771+ [](std::vector<int64_t > shape, PyType &elementType,
772+ std::optional<PyAttribute> &encodingAttr,
773+ DefaultingPyMlirContext context) {
774+ PyMlirContext::ErrorCapture errors (context->getRef ());
775+ MlirType t = mlirRankedTensorTypeGet (
776+ shape.size (), shape.data (), elementType,
777+ encodingAttr ? encodingAttr->get () : mlirAttributeGetNull ());
778+ if (mlirTypeIsNull (t))
779+ throw MLIRError (" Invalid type" , errors.take ());
780+ return PyRankedTensorType (elementType.getContext (), t);
781+ },
782+ nb::arg (" shape" ), nb::arg (" element_type" ),
783+ nb::arg (" encoding" ) = nb::none (), nb::arg (" context" ) = nb::none (),
784+ " Create a ranked tensor type" );
727785 c.def_prop_ro (
728786 " encoding" ,
729787 [](PyRankedTensorType &self)
@@ -758,6 +816,17 @@ class PyUnrankedTensorType
758816 },
759817 nb::arg (" element_type" ), nb::arg (" loc" ) = nb::none (),
760818 " Create a unranked tensor type" );
819+ c.def_static (
820+ " get_unchecked" ,
821+ [](PyType &elementType, DefaultingPyMlirContext context) {
822+ PyMlirContext::ErrorCapture errors (context->getRef ());
823+ MlirType t = mlirUnrankedTensorTypeGet (elementType);
824+ if (mlirTypeIsNull (t))
825+ throw MLIRError (" Invalid type" , errors.take ());
826+ return PyUnrankedTensorType (elementType.getContext (), t);
827+ },
828+ nb::arg (" element_type" ), nb::arg (" context" ) = nb::none (),
829+ " Create a unranked tensor type" );
761830 }
762831};
763832
@@ -790,6 +859,27 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
790859 nb::arg (" shape" ), nb::arg (" element_type" ),
791860 nb::arg (" layout" ) = nb::none (), nb::arg (" memory_space" ) = nb::none (),
792861 nb::arg (" loc" ) = nb::none (), " Create a memref type" )
862+ .def_static (
863+ " get_unchecked" ,
864+ [](std::vector<int64_t > shape, PyType &elementType,
865+ PyAttribute *layout, PyAttribute *memorySpace,
866+ DefaultingPyMlirContext context) {
867+ PyMlirContext::ErrorCapture errors (context->getRef ());
868+ MlirAttribute layoutAttr =
869+ layout ? *layout : mlirAttributeGetNull ();
870+ MlirAttribute memSpaceAttr =
871+ memorySpace ? *memorySpace : mlirAttributeGetNull ();
872+ MlirType t =
873+ mlirMemRefTypeGet (elementType, shape.size (), shape.data (),
874+ layoutAttr, memSpaceAttr);
875+ if (mlirTypeIsNull (t))
876+ throw MLIRError (" Invalid type" , errors.take ());
877+ return PyMemRefType (elementType.getContext (), t);
878+ },
879+ nb::arg (" shape" ), nb::arg (" element_type" ),
880+ nb::arg (" layout" ) = nb::none (),
881+ nb::arg (" memory_space" ) = nb::none (),
882+ nb::arg (" context" ) = nb::none (), " Create a memref type" )
793883 .def_prop_ro (
794884 " layout" ,
795885 [](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
@@ -858,6 +948,22 @@ class PyUnrankedMemRefType
858948 },
859949 nb::arg (" element_type" ), nb::arg (" memory_space" ).none (),
860950 nb::arg (" loc" ) = nb::none (), " Create a unranked memref type" )
951+ .def_static (
952+ " get_unchecked" ,
953+ [](PyType &elementType, PyAttribute *memorySpace,
954+ DefaultingPyMlirContext context) {
955+ PyMlirContext::ErrorCapture errors (context->getRef ());
956+ MlirAttribute memSpaceAttr = {};
957+ if (memorySpace)
958+ memSpaceAttr = *memorySpace;
959+
960+ MlirType t = mlirUnrankedMemRefTypeGet (elementType, memSpaceAttr);
961+ if (mlirTypeIsNull (t))
962+ throw MLIRError (" Invalid type" , errors.take ());
963+ return PyUnrankedMemRefType (elementType.getContext (), t);
964+ },
965+ nb::arg (" element_type" ), nb::arg (" memory_space" ).none (),
966+ nb::arg (" context" ) = nb::none (), " Create a unranked memref type" )
861967 .def_prop_ro (
862968 " memory_space" ,
863969 [](PyUnrankedMemRefType &self)
0 commit comments